Repository: PacktPublishing/Hands-On-Dependency-Injection-in-Go Branch: master Commit: de70ca622e41 Files: 656 Total size: 1.5 MB Directory structure: gitextract_t7lucnow/ ├── .gitignore ├── LICENSE ├── README.md ├── ch01/ │ ├── 01_defining_depenency_injection/ │ │ ├── 01_interface.go │ │ ├── 02_function_literal.go │ │ ├── 03_test_without_nfs_test.go │ │ └── 04_fail_test_without_nfs_test.go │ └── 02_code_smells/ │ ├── 01_code_bloat/ │ │ └── 01_switch_type.go │ ├── 02_resistance_to_change/ │ │ └── 01_shotgun_surgey.go │ ├── 03_wasted_effort/ │ │ ├── 01_excessive_comments.go │ │ └── 02_complicated_go.go │ └── 04_tight_coupling/ │ ├── 01_circular_dependencies/ │ │ ├── config/ │ │ │ └── config.go │ │ └── payment/ │ │ └── currency.go │ ├── 02_object_orgy.go │ └── 03_feature_envy.go ├── ch02/ │ ├── 01_single_responsibility_principle/ │ │ ├── 01_responsibility_vs_change.go │ │ ├── 02_responsibility_vs_change.go │ │ ├── 03_responsibility_vs_change.go │ │ ├── 04_long_method.go │ │ ├── 04_long_method_test.go │ │ └── 05_srp_method.go │ ├── 02_open_closed_principle/ │ │ ├── 01_open_closed_failure.go │ │ ├── 02_open_closed_success.go │ │ ├── 03_shotgun_surgery.go │ │ ├── 04_after_shotgun_surgery.go │ │ ├── 05_composition.go │ │ ├── 06_handler_struct.go │ │ └── 07_handler_func.go │ ├── 03_liskov_substitution_principle/ │ │ ├── 01_violation/ │ │ │ └── example.go │ │ ├── 02_fixed/ │ │ │ └── example.go │ │ ├── 03_fixed/ │ │ │ └── example.go │ │ ├── 04_behaviour.go │ │ └── 05_behaviour_fixed.go │ └── 04_interface_segregation_principle/ │ ├── 01_fat_interface.go │ ├── 02_thin_interface.go │ ├── 03_repeated_inputs.go │ ├── 04_repeated_inputs.go │ ├── 05_repeated_inputs.go │ └── 06_implicit_interfaces.go ├── ch03/ │ ├── 01_optimizing_for_humans/ │ │ ├── 01_not_so_simple.go │ │ ├── 02_start_simple.go │ │ ├── 03_too_abstract.go │ │ ├── 04_common_concept.go │ │ ├── 05_boolean_param.go │ │ ├── 06_hidden_boolean.go │ │ ├── 07_wide_formatter.go │ │ ├── 08_thin_formatters.go │ │ └── 
09_extra_config.go │ ├── 02_unit_tests/ │ │ ├── 01_loader.go │ │ ├── 02_language_feature.go │ │ ├── 03_simple_test.go │ │ ├── 04_test_from_api.go │ │ ├── 05_repeated_code.go │ │ ├── 06_tdt.go │ │ ├── 07_person_loader.go │ │ ├── 08_stub.go │ │ ├── 09_stub_tdt.go │ │ └── 10_mocks.go │ ├── 03_test_induced_damage/ │ │ ├── 01_io_closer.go │ │ └── 02_json.go │ ├── 04_visualizing_dependencies/ │ │ └── depgraph.sh │ └── fake.go ├── ch04/ │ ├── 01_welcome/ │ │ ├── 01_bad_names.go │ │ ├── 02_improved_names.go │ │ ├── 03_long_method.go │ │ ├── 04_long_method_test.go │ │ └── 05_short_methods.go │ ├── 03_known_issues/ │ │ ├── 01_data_and_rest/ │ │ │ └── get_example.go │ │ └── 02_config_coupling/ │ │ ├── config.go │ │ └── currency/ │ │ └── currency.go │ ├── acme/ │ │ ├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── data.go │ │ │ │ │ └── data_test.go │ │ │ │ ├── exchange/ │ │ │ │ │ └── converter.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ └── go_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ └── list_test.go │ │ │ │ └── register/ │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── common_test.go │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── server.go │ │ └── main.go │ └── fake.go ├── ch05/ │ ├── 02_advantages/ │ │ ├── 01_function.go │ │ ├── 02_monkey_patched.go │ │ ├── 03_injected_lambda.go │ │ ├── 04_as_object.go │ │ ├── 05_math_rand.go │ │ └── 06_math_rand_test.go │ ├── 03_applying/ │ │ ├── 01_simple_sqlmock_test.go │ │ └── 02_load.go │ ├── 04_disadvantages/ │ │ ├── 01_verbose.go │ │ ├── 02_verbose_test.go │ │ └── 03_refactored_test.go │ ├── acme/ │ │ ├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── 
logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── data.go │ │ │ │ │ └── data_test.go │ │ │ │ ├── exchange/ │ │ │ │ │ └── converter.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ └── go_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ └── list_test.go │ │ │ │ └── register/ │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── common_test.go │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── server.go │ │ └── main.go │ └── fake.go ├── ch06/ │ ├── 01_constructor_injection/ │ │ ├── 01_welcome_email.go │ │ ├── 01_welcome_email_test.go │ │ ├── 02_mailer_interface.go │ │ ├── 03_sender_interface.go │ │ └── 05_duck_typing.go │ ├── 02_advantages/ │ │ ├── 01_easy_to_implement.go │ │ ├── 01_easy_to_implement_example_test.go │ │ ├── 02_easy_to_implement.go │ │ ├── 02_easy_to_implement_example_test.go │ │ ├── 03_predictable.go │ │ ├── 04_predictable.go │ │ ├── 05_encapsulation.go │ │ └── 06_encapsulation.go │ ├── 03_applying/ │ │ ├── 01/ │ │ │ ├── 01_register_handler_before.go │ │ │ ├── data/ │ │ │ │ └── person.go │ │ │ └── register/ │ │ │ └── register.go │ │ ├── 02/ │ │ │ ├── 01_register_handler.go │ │ │ ├── data/ │ │ │ │ └── person.go │ │ │ └── register/ │ │ │ └── register.go │ │ ├── 03/ │ │ │ ├── data/ │ │ │ │ └── person.go │ │ │ ├── mock_register_model_test.go │ │ │ └── register_test.go │ │ ├── 04/ │ │ │ ├── data/ │ │ │ │ └── person.go │ │ │ ├── mock_register_model_test.go │ │ │ ├── register.go │ │ │ └── register_test.go │ │ └── 05/ │ │ ├── data/ │ │ │ └── person.go │ │ ├── fakes.go │ │ ├── get/ │ │ │ └── getter.go │ │ ├── list/ │ │ │ └── lister.go │ │ ├── register/ │ │ │ └── registerer.go │ │ └── server.go │ ├── 04_disadvantages/ │ │ ├── 01_lots_of_changes.go │ │ ├── 02_overuse.go │ │ ├── 03_non_obvious.go │ │ ├── 04_non_obvious_example_test.go │ │ └── 05_constructors.go │ ├── acme/ │ │ 
├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── data.go │ │ │ │ │ └── data_test.go │ │ │ │ ├── exchange/ │ │ │ │ │ └── converter.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ └── go_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ └── list_test.go │ │ │ │ └── register/ │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── mock_get_model_test.go │ │ │ ├── mock_list_model_test.go │ │ │ ├── mock_register_model_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── server.go │ │ └── main.go │ ├── fake.go │ └── pcov-html ├── ch07/ │ ├── 01_method_injection/ │ │ ├── 01_fprint.go │ │ ├── 02_http_request.go │ │ ├── 03_fprint.go │ │ ├── 04_http_request.go │ │ ├── 05_timestamp_writer_v1.go │ │ ├── 06_timestamp_writer_v2.go │ │ └── 07_timestamp_writer_v3.go │ ├── 02_advantages/ │ │ ├── 01_handler_v1.go │ │ ├── 02_handler_v2.go │ │ ├── 03_handler_v3.go │ │ ├── 04_context_influence.go │ │ └── 05_person_loader.go │ ├── 04_disadvantages/ │ │ ├── 01_data_struct.go │ │ ├── 02_ux_improvement.go │ │ ├── 03_many_params.go │ │ └── 04_many_params_v2.go │ ├── acme/ │ │ ├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── data.go │ │ │ │ │ └── data_test.go │ │ │ │ ├── exchange/ │ │ │ │ │ └── converter.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ └── go_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ └── list_test.go │ │ │ │ └── register/ │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── mock_get_model_test.go │ │ │ ├── mock_list_model_test.go │ 
│ │ ├── mock_register_model_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── server.go │ │ └── main.go │ └── fake.go ├── ch08/ │ ├── 01_config_injection/ │ │ ├── 01_long_constructor.go │ │ ├── 02_by_config_example.go │ │ └── 03_shared_params.go │ ├── 02_advantages/ │ │ ├── 01_injected_config/ │ │ │ ├── 01.go │ │ │ └── 01_test.go │ │ ├── 02_config_injection/ │ │ │ ├── 02.go │ │ │ └── 02_test.go │ │ ├── 03_long_constructor.go │ │ ├── 04_by_config_example.go │ │ ├── config/ │ │ │ └── config.go │ │ ├── logging/ │ │ │ └── logger.go │ │ └── stats/ │ │ └── stats.go │ ├── 03_applying/ │ │ ├── 01_define_register_config.go │ │ ├── 02_register_with_config_injection.go │ │ ├── 03_model_before_data_changes.go │ │ ├── 04_test_config_link_to_config_package.go │ │ ├── 05_result_payload.json │ │ └── 06_simple_test_server.go │ ├── 04_disadvantages/ │ │ ├── 01_leaking_details.go │ │ ├── 02_hiding_details.go │ │ ├── 03_unclear_lifecycle.go │ │ ├── 04_clear_lifecycle.go │ │ └── 05_layers.go │ ├── acme/ │ │ ├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── data.go │ │ │ │ │ └── data_test.go │ │ │ │ ├── exchange/ │ │ │ │ │ ├── converter.go │ │ │ │ │ ├── converter_ext_bounday_test.go │ │ │ │ │ └── converter_int_bounday_test.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ └── go_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ └── list_test.go │ │ │ │ └── register/ │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── mock_get_model_test.go │ │ │ ├── mock_list_model_test.go │ │ │ ├── mock_register_model_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── server.go │ │ └── main.go │ └── fake.go ├── ch09/ │ ├── 
01_jit_injection/ │ │ ├── 01_injecting_db.go │ │ ├── 01_injecting_db_test.go │ │ ├── 02_injecting_business_logic.go │ │ ├── 03_injecting_db_jit.go │ │ ├── 03_injecting_db_jit_test.go │ │ └── 04_noop_debugger.go │ ├── 02_advantages/ │ │ ├── 01_long_constructor.go │ │ ├── 02_short_constructor.go │ │ ├── 03_optional_dep_without_jitdi.go │ │ ├── 04_optional_dep_with_jitdi.go │ │ ├── 05_loader.go │ │ ├── 06_global_variable/ │ │ │ └── 06_global_variable.go │ │ ├── 07_global_variable_jit/ │ │ │ ├── 07_global_variable_jit.go │ │ │ └── 07_global_variable_jit_test.go │ │ ├── 08_car_v1.go │ │ └── 09_car_v2.go │ ├── 03_applying/ │ │ ├── 01_commands.sh │ │ ├── 02_coverage.txt │ │ └── 03_initial_dao.go │ ├── 04_disadvantages/ │ │ ├── 01_uncertain_init_state.go │ │ ├── 02_certain_init_state.go │ │ ├── 03_cpool_slow_constructor.go │ │ └── 04_get_pool_with_once.go │ ├── acme/ │ │ ├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── dao.go │ │ │ │ │ ├── data.go │ │ │ │ │ ├── data_test.go │ │ │ │ │ └── tracker.go │ │ │ │ ├── exchange/ │ │ │ │ │ ├── converter.go │ │ │ │ │ ├── converter_ext_bounday_test.go │ │ │ │ │ └── converter_int_bounday_test.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ ├── go_test.go │ │ │ │ │ └── mock_my_loader_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ ├── list_test.go │ │ │ │ │ └── mock_my_loader_test.go │ │ │ │ └── register/ │ │ │ │ ├── mock_my_saver_test.go │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── mock_get_model_test.go │ │ │ ├── mock_list_model_test.go │ │ │ ├── mock_register_model_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── server.go │ │ └── main.go │ └── fake.go ├── ch10/ │ ├── 01_intro_to_wire/ │ │ ├── 01_simple/ │ │ │ ├── 
main.go │ │ │ ├── wire.go │ │ │ └── wire_gen.go │ │ ├── 02_params/ │ │ │ ├── main.go │ │ │ ├── wire.go │ │ │ └── wire_gen.go │ │ ├── 03_error/ │ │ │ ├── main.go │ │ │ ├── wire.go │ │ │ └── wire_gen.go │ │ └── 04_without_pset/ │ │ ├── main.go │ │ ├── wire.go │ │ └── wire_gen.go │ ├── 02_advantages/ │ │ ├── 01_dig/ │ │ │ └── main.go │ │ └── 02_instantiation_order/ │ │ ├── handler.go │ │ ├── injectors.go │ │ ├── main.go │ │ ├── model.go │ │ ├── providers.go │ │ └── wire_gen.go │ ├── 03_applying/ │ │ ├── 01_before_config/ │ │ │ └── main.go │ │ ├── 02_after_config/ │ │ │ ├── main.go │ │ │ ├── wire.go │ │ │ └── wire_gen.go │ │ ├── 03_after_exchange/ │ │ │ ├── main.go │ │ │ └── wire.go │ │ ├── 04_after_model/ │ │ │ ├── main.go │ │ │ └── wire.go │ │ ├── 05_after_rest/ │ │ │ ├── main.go │ │ │ ├── wire.go │ │ │ └── wire_gen.go │ │ ├── 06_build_tag.go │ │ ├── 06_build_tag_inverse.go │ │ └── 06_main.go │ ├── 04_disadvantages/ │ │ └── 01_complexity/ │ │ └── main.go │ ├── acme/ │ │ ├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── dao.go │ │ │ │ │ ├── data.go │ │ │ │ │ ├── data_test.go │ │ │ │ │ └── tracker.go │ │ │ │ ├── exchange/ │ │ │ │ │ ├── converter.go │ │ │ │ │ ├── converter_ext_bounday_test.go │ │ │ │ │ └── converter_int_bounday_test.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ ├── go_test.go │ │ │ │ │ └── mock_my_loader_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ ├── list_test.go │ │ │ │ │ └── mock_my_loader_test.go │ │ │ │ └── register/ │ │ │ │ ├── mock_my_saver_test.go │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── mock_get_model_test.go │ │ │ ├── mock_list_model_test.go │ │ │ ├── mock_register_model_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── 
server.go │ │ ├── main.go │ │ ├── main_test.go │ │ ├── wire.go │ │ └── wire_gen.go │ └── fake.go ├── ch11/ │ ├── 01_di_induced_damage/ │ │ ├── 01_long_param/ │ │ │ └── 01_long_param.go │ │ ├── 02_long_param/ │ │ │ └── 01_long_param.go │ │ ├── 03_long_param/ │ │ │ └── 01_long_param.go │ │ ├── 04_long_param/ │ │ │ ├── 01_long_param.go │ │ │ └── 01_long_param_test.go │ │ ├── 05_inject_sql/ │ │ │ └── 01_interface.go │ │ ├── 06_inject_sql/ │ │ │ ├── 01_interface.go │ │ │ ├── 02_implementation.go │ │ │ ├── 02_implementation_test.go │ │ │ ├── dao.go │ │ │ ├── data.go │ │ │ └── data_test.go │ │ ├── 07_needless_indirection/ │ │ │ └── example_test.go │ │ ├── 08_needless_indirection/ │ │ │ ├── 01_mux.go │ │ │ ├── 01_mux_test.go │ │ │ └── mock_my_mux_test.go │ │ ├── 09_needless_indirection/ │ │ │ ├── 01_mux.go │ │ │ └── 01_mux_test.go │ │ ├── 10_needless_indirection/ │ │ │ ├── 01_mux_e2e.go │ │ │ └── 01_mux_e2e_test.go │ │ └── 11_service_locator/ │ │ ├── 01_service_locator.go │ │ └── 02_usage.go │ ├── 02_premature_future/ │ │ └── get.go │ ├── 03_mocking_http_requests/ │ │ ├── converter.go │ │ ├── converter_test.go │ │ └── mock_requester_test.go │ └── acme/ │ ├── internal/ │ │ ├── config/ │ │ │ ├── config.go │ │ │ └── config_test.go │ │ ├── logging/ │ │ │ └── logging.go │ │ ├── modules/ │ │ │ ├── data/ │ │ │ │ ├── dao.go │ │ │ │ ├── data.go │ │ │ │ ├── data_test.go │ │ │ │ └── tracker.go │ │ │ ├── exchange/ │ │ │ │ ├── converter.go │ │ │ │ ├── converter_ext_bounday_test.go │ │ │ │ └── converter_int_bounday_test.go │ │ │ ├── get/ │ │ │ │ ├── get.go │ │ │ │ ├── go_test.go │ │ │ │ └── mock_my_loader_test.go │ │ │ ├── list/ │ │ │ │ ├── list.go │ │ │ │ ├── list_test.go │ │ │ │ └── mock_my_loader_test.go │ │ │ └── register/ │ │ │ ├── mock_exchanger_test.go │ │ │ ├── mock_my_saver_test.go │ │ │ ├── register.go │ │ │ └── register_test.go │ │ └── rest/ │ │ ├── get.go │ │ ├── get_test.go │ │ ├── list.go │ │ ├── list_test.go │ │ ├── mock_get_model_test.go │ │ ├── mock_list_model_test.go │ 
│ ├── mock_register_model_test.go │ │ ├── not_found.go │ │ ├── not_found_test.go │ │ ├── register.go │ │ ├── register_test.go │ │ └── server.go │ ├── main.go │ ├── main_test.go │ ├── wire.go │ └── wire_gen.go ├── ch12/ │ ├── 01_improvements/ │ │ └── 01_test_logging_test.go │ ├── 03_testing/ │ │ ├── 01_mock_get_model.go │ │ ├── 02_coverage_ch04.txt │ │ ├── 03_coverage_ch11.txt │ │ ├── 04_coverage_config.htm │ │ ├── 04_coverage_data.htm │ │ ├── 04_coverage_exchange.htm │ │ ├── 04_coverage_get.htm │ │ ├── 04_coverage_list.htm │ │ ├── 04_coverage_main.htm │ │ ├── 04_coverage_register.htm │ │ └── 04_coverage_rest.htm │ ├── 04_new_service/ │ │ └── 01_data_with_cache/ │ │ ├── dao.go │ │ ├── data.go │ │ └── internal/ │ │ ├── cache/ │ │ │ └── cache.go │ │ └── logging/ │ │ └── logging.go │ ├── acme/ │ │ ├── internal/ │ │ │ ├── config/ │ │ │ │ ├── config.go │ │ │ │ └── config_test.go │ │ │ ├── logging/ │ │ │ │ └── logging.go │ │ │ ├── modules/ │ │ │ │ ├── data/ │ │ │ │ │ ├── dao.go │ │ │ │ │ ├── data.go │ │ │ │ │ ├── data_test.go │ │ │ │ │ └── tracker.go │ │ │ │ ├── exchange/ │ │ │ │ │ ├── converter.go │ │ │ │ │ ├── converter_ext_bounday_test.go │ │ │ │ │ └── converter_int_bounday_test.go │ │ │ │ ├── get/ │ │ │ │ │ ├── get.go │ │ │ │ │ ├── go_test.go │ │ │ │ │ └── mock_my_loader_test.go │ │ │ │ ├── list/ │ │ │ │ │ ├── list.go │ │ │ │ │ ├── list_test.go │ │ │ │ │ └── mock_my_loader_test.go │ │ │ │ └── register/ │ │ │ │ ├── mock_exchanger_test.go │ │ │ │ ├── mock_my_saver_test.go │ │ │ │ ├── register.go │ │ │ │ └── register_test.go │ │ │ └── rest/ │ │ │ ├── get.go │ │ │ ├── get_test.go │ │ │ ├── list.go │ │ │ ├── list_test.go │ │ │ ├── mock_get_model_test.go │ │ │ ├── mock_list_model_test.go │ │ │ ├── mock_register_model_test.go │ │ │ ├── not_found.go │ │ │ ├── not_found_test.go │ │ │ ├── register.go │ │ │ ├── register_test.go │ │ │ └── server.go │ │ ├── main.go │ │ ├── main_test.go │ │ ├── wire.go │ │ └── wire_gen.go │ └── fake.go ├── default-config.json ├── fake.go ├── 
resources/ │ └── create.sql └── vendor/ ├── github.com/ │ ├── DATA-DOG/ │ │ └── go-sqlmock/ │ │ ├── LICENSE │ │ ├── README.md │ │ ├── argument.go │ │ ├── driver.go │ │ ├── expectations.go │ │ ├── expectations_before_go18.go │ │ ├── expectations_go18.go │ │ ├── result.go │ │ ├── rows.go │ │ ├── rows_go18.go │ │ ├── sqlmock.go │ │ ├── sqlmock_go18.go │ │ ├── statement.go │ │ └── util.go │ ├── davecgh/ │ │ └── go-spew/ │ │ ├── LICENSE │ │ └── spew/ │ │ ├── bypass.go │ │ ├── bypasssafe.go │ │ ├── common.go │ │ ├── config.go │ │ ├── doc.go │ │ ├── dump.go │ │ ├── format.go │ │ └── spew.go │ ├── go-sql-driver/ │ │ └── mysql/ │ │ ├── AUTHORS │ │ ├── CHANGELOG.md │ │ ├── CONTRIBUTING.md │ │ ├── LICENSE │ │ ├── README.md │ │ ├── appengine.go │ │ ├── buffer.go │ │ ├── collations.go │ │ ├── connection.go │ │ ├── connection_go18.go │ │ ├── const.go │ │ ├── driver.go │ │ ├── dsn.go │ │ ├── errors.go │ │ ├── fields.go │ │ ├── infile.go │ │ ├── packets.go │ │ ├── result.go │ │ ├── rows.go │ │ ├── statement.go │ │ ├── transaction.go │ │ ├── utils.go │ │ ├── utils_go17.go │ │ └── utils_go18.go │ ├── google/ │ │ └── wire/ │ │ ├── AUTHORS │ │ ├── CODE_OF_CONDUCT.md │ │ ├── CONTRIBUTING.md │ │ ├── CONTRIBUTORS │ │ ├── LICENSE │ │ ├── README.md │ │ ├── go.mod │ │ ├── go.sum │ │ └── wire.go │ ├── gorilla/ │ │ ├── context/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── context.go │ │ │ └── doc.go │ │ └── mux/ │ │ ├── ISSUE_TEMPLATE.md │ │ ├── LICENSE │ │ ├── README.md │ │ ├── context_gorilla.go │ │ ├── context_native.go │ │ ├── doc.go │ │ ├── middleware.go │ │ ├── mux.go │ │ ├── regexp.go │ │ ├── route.go │ │ └── test_helpers.go │ ├── pmezard/ │ │ └── go-difflib/ │ │ ├── LICENSE │ │ └── difflib/ │ │ └── difflib.go │ └── stretchr/ │ ├── objx/ │ │ ├── LICENSE.md │ │ ├── README.md │ │ ├── accessors.go │ │ ├── constants.go │ │ ├── conversions.go │ │ ├── doc.go │ │ ├── map.go │ │ ├── mutations.go │ │ ├── security.go │ │ ├── tests.go │ │ ├── type_specific_codegen.go │ │ └── value.go │ └── 
testify/ │ ├── LICENSE │ ├── assert/ │ │ ├── assertion_format.go │ │ ├── assertion_format.go.tmpl │ │ ├── assertion_forward.go │ │ ├── assertion_forward.go.tmpl │ │ ├── assertions.go │ │ ├── doc.go │ │ ├── errors.go │ │ ├── forward_assertions.go │ │ └── http_assertions.go │ ├── mock/ │ │ ├── doc.go │ │ └── mock.go │ └── require/ │ ├── doc.go │ ├── forward_requirements.go │ ├── require.go │ ├── require.go.tmpl │ ├── require_forward.go │ ├── require_forward.go.tmpl │ └── requirements.go ├── go.uber.org/ │ └── dig/ │ ├── CHANGELOG.md │ ├── LICENSE │ ├── Makefile │ ├── README.md │ ├── check_license.sh │ ├── cycle.go │ ├── dig.go │ ├── doc.go │ ├── error.go │ ├── glide.yaml │ ├── internal/ │ │ ├── digreflect/ │ │ │ └── func.go │ │ └── dot/ │ │ └── graph.go │ ├── param.go │ ├── result.go │ ├── stringer.go │ ├── types.go │ └── version.go ├── google.golang.org/ │ └── appengine/ │ ├── LICENSE │ └── cloudsql/ │ ├── cloudsql.go │ ├── cloudsql_classic.go │ └── cloudsql_vm.go └── vendor.json ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Created by .ignore support plugin (hsz.mobi) ### Go template # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib # Test binary, build with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out config.json ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2018 Packt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the 
Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Hands-On Dependency Injection in Go Hands-On Dependency Injection in Go This is the code repository for [Hands-On Dependency Injection in Go](https://www.packtpub.com/application-development/hands-dependency-injection-go?utm_source=github&utm_medium=repository&utm_campaign=9781789132762 ), published by Packt. **Develop clean Go code that is easier to read, maintain, and test** ## What is this book about? Hands-On Dependency Injection in Go takes you on a journey, refactoring existing code to adopt dependency injection (DI) using various methods available in Go. This book covers the following exciting features: * Understand the benefits of dependency injection * Explore SOLID design principles and how they relate to Go * Analyze various dependency injection patterns available in Go * Leverage DI to produce high quality, loosely coupled Go code * Refactor existing Go code to adopt dependency injection * Discover tools to improve your code's testability and test coverage * Generate and interpret Go dependency graphs If you feel this book is for you, get your [copy](https://www.amazon.com/dp/1789132762) today! https://www.packtpub.com/ ## Instructions and Navigations All of the code is organized into folders. 
For example, ch02. The code will look like the following: ``` html, body, #map { height: 100%; margin: 0; padding: 0 } ``` **Following is what you need for this book:** Hands-On Dependency Injection in Go is for programmers with a few years' experience in any language and a basic understanding of Go. If you wish to produce clean, loosely coupled code that is inherently easier to test, this book is for you. ## Getting the source The easiest way to obtain the source code is to use `go get`. This will ensure that the code is placed in the correct directory and should then be runnable and testable. To download this repo, use `go get github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/...` ## Code Organization In this repository, there is one folder for every chapter of the book, named chXX, where XX is the chapter number. The code provided is an expanded version of the code presented in the book. While it will compile and typically will not throw an error when passed into `go test`, it is not designed to be executed. From chapter 4 onwards, there is an `acme` directory included with the code for that chapter. The `acme` directory is the code for the sample service presented in the book with the changes discussed in that chapter already applied. You will also find 2 additional directories in the root of the repository: * **resources** - this directory contains an SQL file that should be used to populate a MySQL database. This database is used by the sample service. * **vendor** - this is the standard Go vendor directory, which contains the external packages required by the sample service ## Setting up the MySQL database The easiest way to create and populate the database required by the sample service is by running the following: `mysql < ./resources/create.sql` Depending on your settings, you may want to provide a username and password like this: `mysql -u [your username] -p < ./resources/create.sql` This will create a database called `acme` with 1 table and 4 records.
## Creating a free account on CurrencyLayer The sample service uses a free currency conversion service. In order to successfully run all the examples, you will need to sign up [here](https://currencylayer.com/) and obtain an API Key. ## Configuring the sample service Now that you have your MySQL and CurrencyLayer credentials, you can create a config for the sample service. 1. Copy `default-config.json` (found next to this file) to `config.json` 1. Open `config.json` in your favorite editor 1. Add your database credentials to the `"dsn"` setting. It should be in the form: `"[username]:[password]@tcp(localhost:3306)/[database name]?autocommit=true"` 1. Add your API Key to the `"exchangeRateAPIKey"` setting. It should be in the form: `"1234567890abcdef1234567890abcdef"` ## Running the sample service for a particular chapter To run the sample service for a particular chapter: 1. First make sure you are in the base of this repository: `cd $GOPATH/src/github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/` 1. Use a command similar to the following (which is for ch04): `ACME_CONFIG=$GOPATH/src/github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/config.json go run ./ch04/acme/main.go` ### Special instructions for chapters 10-12 As we have multiple files and tests in the `main` package, we cannot use the standard `go run ./ch10/acme/main.go` to run the service. Instead, we need to modify the command to `go run ./ch10/acme/main.go ./ch10/acme/wire_gen.go` ## Running tests for a chapter To run the tests for a particular chapter, you can use a command similar to the following (which is for ch04): 1. First make sure you are in the base of this repository: `cd $GOPATH/src/github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/` 1.
Use a command similar to the following (which is for ch04): `ACME_CONFIG=$GOPATH/src/github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/config.json go test ./ch04/...` With the following software and hardware list you can run all code files present in the book (Chapter 1-12). ### Software and Hardware List | Chapter | Software required | OS required | | -------- | ------------------------------------ | -----------------------------------| | 1-12 | Go 1.10.x+ | Windows, Mac OS X, and Linux (Any) | | 4-12 | MySQL 5.7.x+ | Windows, Mac OS X, and Linux (Any) | | 4-12 | CurrencyLayer | Windows, Mac OS X, and Linux (Any) | ### Related products * Mastering Go [[Packt]](https://www.packtpub.com/networking-and-servers/mastering-go?utm_source=github&utm_medium=repository&utm_campaign=9781788626545 ) [[Amazon]](https://www.amazon.com/dp/1788626540) * Go Standard Library Cookbook [[Packt]](https://www.packtpub.com/application-development/go-standard-library-cookbook?utm_source=github&utm_medium=repository&utm_campaign=9781788475273 ) [[Amazon]](https://www.amazon.com/dp/1788475275) ## Get to Know the Author **Corey Scott** is a senior software engineer currently living in Melbourne, Australia. He’s been programming professionally since 2000, with the last 5 years spent building large-scale distributed services in Go. An occasional technical speaker and blogger on a variety of software-related topics, he is passionate about designing and building quality software. He believes that software engineering is a craft that should be honed, debated, and continuously improved. He takes a pragmatic, non-zealot approach to coding and is always up for a good debate about software engineering, continuous delivery, testing, or clean coding. ## Other books by the authors ### Suggestions and Feedback [Click here](https://docs.google.com/forms/d/e/1FAIpQLSdy7dATC6QmEL81FIUuymZ0Wy9vH1jHkvpY57OiMeKGqib_Ow/viewform) if you have any feedback or suggestions. 
### Download a free PDF If you have already purchased a print or Kindle version of this book, you can get a DRM-free PDF version at no cost.
Simply click on the link to claim your free PDF.

https://packt.link/free-ebook/9781789132762

================================================
FILE: ch01/01_defining_depenency_injection/01_interface.go
================================================
package defining_depenency_injection

import (
	"encoding/json"
	"errors"
)

// Saver persists the supplied bytes.
//
// SavePerson receives its Saver as a parameter instead of constructing one,
// so the caller (or a test, see mockSaver) decides where the bytes end up.
type Saver interface {
	Save(data []byte) error
}

// SavePerson will validate and persist the supplied person.
//
// Steps, in order: validate person, JSON-encode it, hand the bytes to saver.
// The first error from any step is returned unchanged; saver.Save is not
// called when validation or encoding fails. person must be non-nil because
// validate reads its fields.
func SavePerson(person *Person, saver Saver) error {
	// validate the inputs
	err := person.validate()
	if err != nil {
		return err
	}

	// encode person to bytes
	bytes, err := person.encode()
	if err != nil {
		return err
	}

	// save the person and return the result
	return saver.Save(bytes)
}

// Person data object.
//
// The fields carry no JSON tags, so encode emits the keys "Name" and "Phone".
type Person struct {
	Name  string
	Phone string
}

// validate the person object.
//
// Returns an error "name missing" when Name is empty, otherwise "phone missing"
// when Phone is empty (Name is checked first), otherwise nil.
func (p *Person) validate() error {
	if p.Name == "" {
		return errors.New("name missing")
	}
	if p.Phone == "" {
		return errors.New("phone missing")
	}
	return nil
}

// convert the person into bytes (JSON, via encoding/json).
func (p *Person) encode() ([]byte, error) {
	return json.Marshal(p)
}

================================================
FILE: ch01/01_defining_depenency_injection/02_function_literal.go
================================================
package defining_depenency_injection

import (
	"errors"
	"fmt"
)

// LoadPerson will load the requested person by ID.
// Errors include: invalid ID, missing person and failure to load or decode.
//
// The decoding step is injected as a function literal (decodePerson), so the
// caller chooses how the raw bytes become a *Person.
//
// NOTE(review): decodePerson has no error return, so a decode failure cannot
// actually surface as an error from LoadPerson. Also, loadPerson is still a
// stub, so any ID > 0 currently yields its "not implemented" error.
func LoadPerson(ID int, decodePerson func(data []byte) *Person) (*Person, error) {
	// validate the input: only positive IDs are accepted
	if ID <= 0 {
		return nil, fmt.Errorf("invalid ID '%d' supplied", ID)
	}

	// load from storage
	bytes, err := loadPerson(ID)
	if err != nil {
		return nil, err
	}

	// decode bytes and return
	return decodePerson(bytes), nil
}

// load person as bytes from storage.
//
// Placeholder: it ignores ID and always returns nil bytes together with a
// "not implemented" error.
func loadPerson(ID int) ([]byte, error) {
	// TODO: implement
	return nil, errors.New("not implemented")
}

================================================
FILE: ch01/01_defining_depenency_injection/03_test_without_nfs_test.go
================================================
package defining_depenency_injection

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// TestSavePerson_happyPath checks that a valid Person is handed to the Saver
// exactly once and that SavePerson returns no error when the Saver succeeds.
// The Saver is a mock, so no real storage (NFS) is needed.
func TestSavePerson_happyPath(t *testing.T) {
	// input
	in := &Person{
		Name:  "Sophia",
		Phone: "0123456789",
	}

	// mock the NFS: expect one Save call with any argument, returning nil
	mockNFS := &mockSaver{}
	mockNFS.On("Save", mock.Anything).Return(nil).Once()

	// Call Save
	resultErr := SavePerson(in, mockNFS)

	// validate result
	assert.NoError(t, resultErr)
	assert.True(t, mockNFS.AssertExpectations(t))
}

// mock implementation of Saver, built on testify's mock.Mock.
// It is also used by 04_fail_test_without_nfs_test.go.
type mockSaver struct {
	mock.Mock
}

// Save implements Saver. It records the call with its argument and returns
// the first value configured via On("Save", ...).Return(...) as the error.
func (m *mockSaver) Save(data []byte) error {
	outputs := m.Mock.Called(data)

	return outputs.Error(0)
}

================================================
FILE: ch01/01_defining_depenency_injection/04_fail_test_without_nfs_test.go
================================================
package defining_depenency_injection

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// TestSavePerson_nfsAlwaysFails checks that an error returned by the Saver is
// propagated by SavePerson. It reuses mockSaver from
// 03_test_without_nfs_test.go.
func TestSavePerson_nfsAlwaysFails(t *testing.T) {
	// input
	in := &Person{
		Name:  "Sophia",
		Phone: "0123456789",
	}

	// mock the NFS: expect one Save call that fails
	mockNFS := &mockSaver{}
	mockNFS.On("Save", mock.Anything).Return(errors.New("save failed")).Once()

	// Call Save
	resultErr := SavePerson(in, mockNFS)

	// validate result
	assert.Error(t, resultErr)
	assert.True(t, mockNFS.AssertExpectations(t))
}
================================================ FILE: ch01/02_code_smells/01_code_bloat/01_switch_type.go ================================================ package code_bloat import ( "strconv" ) func AppendValue(buffer []byte, in interface{}) []byte { var value []byte // convert input to []byte switch concrete := in.(type) { case []byte: value = concrete case string: value = []byte(concrete) case int64: value = []byte(strconv.FormatInt(concrete, 10)) case bool: value = []byte(strconv.FormatBool(concrete)) case float64: value = []byte(strconv.FormatFloat(concrete, 'e', 3, 64)) } buffer = append(buffer, value...) return buffer } ================================================ FILE: ch01/02_code_smells/02_resistance_to_change/01_shotgun_surgey.go ================================================ package _2_resistance_to_change import ( "database/sql" "io" ) // Renderer will render a person to the supplied writer type Renderer struct{} func (r Renderer) render(name, phone string, output io.Writer) { // output the person } // Validator will validate the supplied person has all the required fields type Validator struct{} func (v Validator) validate(name, phone string) error { // validate the person return nil } // Saver will save the supplied person to the DB type Saver struct{} func (s *Saver) Save(db *sql.DB, name, phone string) { // save the person to db } ================================================ FILE: ch01/02_code_smells/03_wasted_effort/01_excessive_comments.go ================================================ package wasted_effort // Excessive comments func outputOrderedPeopleA(in []*Person) { // This code orders people by name. // In cases where the name is the same, it will order by phone number. // The sort algorithm used is a bubble sort // WARNING: this sort will change the items of the input array for range in { // ... sort code removed ... 
} outputPeople(in) } // Comments replaced with descriptive names func outputOrderedPeopleB(in []*Person) { sortPeople(in) outputPeople(in) } func outputPeople(in []*Person) { // TODO: implement } // any special instructions that MUST be documented relating to the sort should go here func sortPeople(in []*Person) { // TODO: implement } // Person data object type Person struct { Name string Phone string } ================================================ FILE: ch01/02_code_smells/03_wasted_effort/02_complicated_go.go ================================================ package wasted_effort import ( "image" "image/color" "math" ) func d(r, v float64, i *image.RGBA, c color.Color) { for a := float64(0); a < 360; a++ { ra := math.Pi * 2 * a / 360 x := r*math.Sin(ra) + v y := r*math.Cos(ra) + v i.Set(int(x), int(y), c) } } ================================================ FILE: ch01/02_code_smells/04_tight_coupling/01_circular_dependencies/config/config.go ================================================ // +build bad package config import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch01/02_code_smells/04_tight_coupling/01_circular_dependencies/payment" ) // Config defines the JSON format of the config file type Config struct { // Address is the host and port to bind to. 
// Default 0.0.0.0:8080 Address string // DefaultCurrency is the default currency of the system DefaultCurrency payment.Currency } // Load will load the JSON config from the file supplied func Load(filename string) (*Config, error) { // TODO: load currency from file return nil, errors.New("not implemented yet") } ================================================ FILE: ch01/02_code_smells/04_tight_coupling/01_circular_dependencies/payment/currency.go ================================================ // +build bad package payment import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch01/02_code_smells/04_tight_coupling/01_circular_dependencies/config" ) // Currency is custom type for currency type Currency string // Processor processes payments type Processor struct { Config *config.Config } // Pay makes a payment in the default currency func (p *Processor) Pay(amount float64) error { // TODO: implement me return errors.New("not implemented yet") } ================================================ FILE: ch01/02_code_smells/04_tight_coupling/02_object_orgy.go ================================================ package _4_tight_coupling import ( "errors" "io/ioutil" "net/http" ) type PageLoader struct { } func (o *PageLoader) LoadPage(url string) ([]byte, error) { b := newFetcher() // check cache payload, err := b.cache.Get(url) if err == nil { // found in cache return payload, nil } // call upstream resp, err := b.httpClient.Get(url) if err != nil { return nil, err } defer resp.Body.Close() // extract data from HTTP response payload, err = ioutil.ReadAll(resp.Body) if err != nil { return nil, err } // save to cache asynchronously go func(key string, value []byte) { b.cache.Set(key, value) }(url, payload) // return return payload, nil } type Fetcher struct { httpClient http.Client cache *Cache } func newFetcher() *Fetcher { return &Fetcher{} } type Cache struct { // not implemented } func (c *Cache) Get(key string) ([]byte, error) { // not 
implemented return nil, errors.New("not implemented") } func (c *Cache) Set(key string, data []byte) error { // not implemented return errors.New("not implemented") } ================================================ FILE: ch01/02_code_smells/04_tight_coupling/03_feature_envy.go ================================================ package _4_tight_coupling import ( "errors" "time" ) type searchRequest struct { query string start time.Time end time.Time } func (request searchRequest) validate() error { if request.query == "" { return errors.New("search term is missing") } if request.start.IsZero() || request.start.After(time.Now()) { return errors.New("start time is missing or invalid") } if request.end.IsZero() || request.end.Before(request.start) { return errors.New("end time is missing or invalid") } return nil } type searchResults struct { result string } func doSearchWithEnvy(request searchRequest) ([]searchResults, error) { // validate request if request.query == "" { return nil, errors.New("search term is missing") } if request.start.IsZero() || request.start.After(time.Now()) { return nil, errors.New("start time is missing or invalid") } if request.end.IsZero() || request.end.Before(request.start) { return nil, errors.New("end time is missing or invalid") } return performSearch(request) } func doSearchWithoutEnvy(request searchRequest) ([]searchResults, error) { err := request.validate() if err != nil { return nil, err } return performSearch(request) } func performSearch(request searchRequest) ([]searchResults, error) { // TODO: implement return nil, errors.New("not implemented") } ================================================ FILE: ch02/01_single_responsibility_principle/01_responsibility_vs_change.go ================================================ package srp import ( "fmt" "io" ) // CalculatorV1 calculates the test coverage for a directory and it's sub-directories type CalculatorV1 struct { // coverage data populated by `Calculate()` method data 
map[string]float64 } // Calculate will calculate the coverage func (c *CalculatorV1) Calculate(path string) error { // run `go test -cover ./[path]/...` and store the results return nil } // Output will print the coverage data to the supplied writer func (c *CalculatorV1) Output(writer io.Writer) { for path, result := range c.data { fmt.Fprintf(writer, "%s -> %.1f\n", path, result) } } ================================================ FILE: ch02/01_single_responsibility_principle/02_responsibility_vs_change.go ================================================ package srp import ( "fmt" "io" ) // CalculatorV2 calculates the test coverage for a directory and it's sub-directories type CalculatorV2 struct { // coverage data populated by `Calculate()` method data map[string]float64 } // Calculate will calculate the coverage func (c *CalculatorV2) Calculate(path string) error { // run `go test -cover ./[path]/...` and store the results return nil } // Output will print the coverage data to the supplied writer func (c CalculatorV2) Output(writer io.Writer) { for path, result := range c.data { fmt.Fprintf(writer, "%s -> %.1f\n", path, result) } } // OutputCSV will print the coverage data to the supplied writer func (c CalculatorV2) OutputCSV(writer io.Writer) { for path, result := range c.data { fmt.Fprintf(writer, "%s,%.1f\n", path, result) } } ================================================ FILE: ch02/01_single_responsibility_principle/03_responsibility_vs_change.go ================================================ package srp import ( "fmt" "io" ) // CalculatorV3 calculates the test coverage for a directory and it's sub-directories type CalculatorV3 struct { // coverage data populated by `Calculate()` method data map[string]float64 } // Calculate will calculate the coverage func (c *CalculatorV3) Calculate(path string) error { // run `go test -cover ./[path]/...` and store the results return nil } func (c *CalculatorV3) getData() map[string]float64 { // copy and return 
the map return nil } type Printer interface { Output(data map[string]float64) } type DefaultPrinter struct { Writer io.Writer } // Output implements Printer func (d *DefaultPrinter) Output(data map[string]float64) { for path, result := range data { fmt.Fprintf(d.Writer, "%s -> %.1f\n", path, result) } } type CSVPrinter struct { Writer io.Writer } // Output implements Printer func (d *CSVPrinter) Output(data map[string]float64) { for path, result := range data { fmt.Fprintf(d.Writer, "%s,%.1f\n", path, result) } } ================================================ FILE: ch02/01_single_responsibility_principle/04_long_method.go ================================================ package srp import ( "database/sql" "encoding/json" "net/http" "strconv" ) func loadUserHandlerLong(resp http.ResponseWriter, req *http.Request) { err := req.ParseForm() if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } userID, err := strconv.ParseInt(req.Form.Get("UserID"), 10, 64) if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } row := DB.QueryRow("SELECT * FROM Users WHERE ID = ?", userID) person := &Person{} err = row.Scan(&person.ID, &person.Name, &person.Phone) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } encoder := json.NewEncoder(resp) err = encoder.Encode(person) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } } var DB *sql.DB type Person struct { ID int64 Name string Phone string } ================================================ FILE: ch02/01_single_responsibility_principle/04_long_method_test.go ================================================ package srp import ( "net/http" "net/http/httptest" "net/url" "os" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" ) func TestLoadUserHandler(t *testing.T) { // build request req := &http.Request{ Form: url.Values{}, } req.Form.Add("UserID", "1234") // call function under test resp := httptest.NewRecorder() 
loadUserHandlerLong(resp, req) // validate result assert.Equal(t, http.StatusOK, resp.Code) expectedBody := `{"ID":1,"Name":"Bob","Phone":"0123456789"}` + "\n" assert.Equal(t, expectedBody, resp.Body.String()) } func TestMain(m *testing.M) { // create fake DB for this test var mock sqlmock.Sqlmock DB, mock, _ = sqlmock.New() // config fake response mock.ExpectQuery(".*").WillReturnRows( sqlmock.NewRows([]string{"ID", "Name", "Phone"}).AddRow( 1, "Bob", "0123456789")) os.Exit(m.Run()) } ================================================ FILE: ch02/01_single_responsibility_principle/05_srp_method.go ================================================ package srp import ( "encoding/json" "net/http" "strconv" ) func loadUserHandlerSRP(resp http.ResponseWriter, req *http.Request) { userID, err := extractIDFromRequest(req) if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } person, err := loadPersonByID(userID) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } outputPerson(resp, person) } func extractIDFromRequest(req *http.Request) (int64, error) { err := req.ParseForm() if err != nil { return 0, err } return strconv.ParseInt(req.Form.Get("UserID"), 10, 64) } func loadPersonByID(userID int64) (*Person, error) { row := DB.QueryRow("SELECT * FROM Users WHERE userID = ?", userID) person := &Person{} err := row.Scan(person.ID, person.Name, person.Phone) if err != nil { return nil, err } return person, nil } func outputPerson(resp http.ResponseWriter, person *Person) { encoder := json.NewEncoder(resp) err := encoder.Encode(person) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } } ================================================ FILE: ch02/02_open_closed_principle/01_open_closed_failure.go ================================================ package ocp import ( "io" "net/http" ) func BuildOutputOCPFail(response http.ResponseWriter, format string, person Person) { var err error switch format { case "csv": err = 
outputCSV(response, person) case "json": err = outputJSON(response, person) } if err != nil { // output a server error and quit response.WriteHeader(http.StatusInternalServerError) return } response.WriteHeader(http.StatusOK) } // output the person as CSV and return error when failing to do so func outputCSV(writer io.Writer, person Person) error { // TODO: implement return nil } // output the person as JSON and return error when failing to do so func outputJSON(writer io.Writer, person Person) error { // TODO: implement return nil } // A data transfer object that represents a person type Person struct { Name string Email string } ================================================ FILE: ch02/02_open_closed_principle/02_open_closed_success.go ================================================ package ocp import ( "io" "net/http" ) func BuildOutputOCPSuccess(response http.ResponseWriter, formatter PersonFormatter, person Person) { err := formatter.Format(response, person) if err != nil { // output a server error and quit response.WriteHeader(http.StatusInternalServerError) return } response.WriteHeader(http.StatusOK) } type PersonFormatter interface { Format(writer io.Writer, person Person) error } // output the person as CSV type CSVPersonFormatter struct{} // Format implements the PersonFormatter interface func (c *CSVPersonFormatter) Format(writer io.Writer, person Person) error { // TODO: implement return nil } // output the person as JSON type JSONPersonFormatter struct{} // Format implements the PersonFormatter interface func (j *JSONPersonFormatter) Format(writer io.Writer, person Person) error { // TODO: implement return nil } ================================================ FILE: ch02/02_open_closed_principle/03_shotgun_surgery.go ================================================ package ocp import ( "net/http" "strconv" ) func GetUserHandlerV1(resp http.ResponseWriter, req *http.Request) { // validate inputs err := req.ParseForm() if err != nil { 
resp.WriteHeader(http.StatusInternalServerError) return } userID, err := strconv.ParseInt(req.Form.Get("UserID"), 10, 64) if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } user := loadUser(userID) outputUser(resp, user) } func DeleteUserHandlerV1(resp http.ResponseWriter, req *http.Request) { // validate inputs err := req.ParseForm() if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } userID, err := strconv.ParseInt(req.Form.Get("UserID"), 10, 64) if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } deleteUser(userID) } func loadUser(userID int64) interface{} { // TODO: implement return nil } func deleteUser(userID int64) { // TODO: implement } func outputUser(resp http.ResponseWriter, user interface{}) { // TODO: implement } ================================================ FILE: ch02/02_open_closed_principle/04_after_shotgun_surgery.go ================================================ package ocp import ( "errors" "net/http" "net/url" "strconv" ) func GetUserHandlerV2(resp http.ResponseWriter, req *http.Request) { // validate inputs err := req.ParseForm() if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } userID, err := extractUserID(req.Form) if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } user := loadUser(userID) outputUser(resp, user) } func DeleteUserHandlerV2(resp http.ResponseWriter, req *http.Request) { // validate inputs err := req.ParseForm() if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } userID, err := extractUserID(req.Form) if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } deleteUser(userID) } func extractUserID(values url.Values) (int64, error) { userID, err := strconv.ParseInt(values.Get("UserID"), 10, 64) if err != nil { return 0, err } if userID <= 0 { return 0, errors.New("userID must be positive") } return userID, nil } ================================================ FILE: 
ch02/02_open_closed_principle/05_composition.go ================================================ package ocp import ( "database/sql" ) type rowConverter struct { } // populate the supplied Person from *sql.Row or *sql.Rows object func (d *rowConverter) populate(in *Person, scan func(dest ...interface{}) error) error { return scan(in.Name, in.Email) } type LoadPerson struct { // compose the row converter into this loader rowConverter } func (loader *LoadPerson) ByID(id int) (Person, error) { row := loader.loadFromDB(id) person := Person{} // call the composed "abstract class" err := loader.populate(&person, row.Scan) return person, err } func (loader *LoadPerson) loadFromDB(id int) *sql.Row { // TODO: implement return nil } type LoadAll struct { // compose the row converter into this loader rowConverter } func (loader *LoadPerson) All() ([]Person, error) { rows := loader.loadAllFromDB() defer rows.Close() var output []Person for rows.Next() { person := Person{} // call the composed "abstract class" err := loader.populate(&person, rows.Scan) if err != nil { return nil, err } } return output, nil } func (loader *LoadPerson) loadAllFromDB() *sql.Rows { // TODO: implement return nil } ================================================ FILE: ch02/02_open_closed_principle/06_handler_struct.go ================================================ package ocp import ( "net/http" ) // a HTTP health check handler in long form type healthCheckLong struct { } func (h *healthCheckLong) ServeHTTP(resp http.ResponseWriter, _ *http.Request) { resp.WriteHeader(http.StatusNoContent) } func healthCheckLongUsage() { http.Handle("/health", &healthCheckLong{}) } ================================================ FILE: ch02/02_open_closed_principle/07_handler_func.go ================================================ package ocp import ( "net/http" ) // a HTTP health check handler in short form func healthCheckShort(resp http.ResponseWriter, _ *http.Request) { resp.WriteHeader(http.StatusNoContent) 
} func healthCheckShortUsage() { http.Handle("/health", http.HandlerFunc(healthCheckShort)) } ================================================ FILE: ch02/03_liskov_substitution_principle/01_violation/example.go ================================================ package lsp_violation func Go(vehicle actions) { if sled, ok := vehicle.(*Sled); ok { sled.pushStart() } else { vehicle.startEngine() } vehicle.drive() } type actions interface { drive() startEngine() } type Vehicle struct { } func (v Vehicle) drive() { // TODO: implement } func (v Vehicle) startEngine() { // TODO: implement } func (v Vehicle) stopEngine() { // TODO: implement } type Car struct { Vehicle } type Sled struct { Vehicle } func (s Sled) startEngine() { // override so that is does nothing } func (s Sled) stopEngine() { // override so that is does nothing } func (s Sled) pushStart() { // TODO: implement } ================================================ FILE: ch02/03_liskov_substitution_principle/02_fixed/example.go ================================================ package fixedv1 func Go(vehicle actions) { switch concrete := vehicle.(type) { case poweredActions: concrete.startEngine() case unpoweredActions: concrete.pushStart() } vehicle.drive() } type actions interface { drive() } type poweredActions interface { actions startEngine() stopEngine() } type unpoweredActions interface { actions pushStart() } type Vehicle struct { } func (v Vehicle) drive() { // TODO: implement } type PoweredVehicle struct { Vehicle } func (v PoweredVehicle) startEngine() { // common engine start code } type Car struct { PoweredVehicle } type Sled struct { Vehicle } func (s Sled) pushStart() { // do nothing } ================================================ FILE: ch02/03_liskov_substitution_principle/03_fixed/example.go ================================================ package fixedv2 func Go(vehicle actions) { vehicle.start() vehicle.drive() } type actions interface { start() drive() } type Car struct { poweredVehicle } 
func (c Car) start() { c.poweredVehicle.startEngine() } func (c Car) drive() { // TODO: implement } type poweredVehicle struct { } func (p poweredVehicle) startEngine() { // common engine start code } type Sled struct { } func (s Sled) start() { // push start } func (s Sled) drive() { // TODO: implement } ================================================ FILE: ch02/03_liskov_substitution_principle/04_behaviour.go ================================================ package lsp type Collection interface { Add(item interface{}) Get(index int) interface{} } type CollectionImpl struct { items []interface{} } func (c *CollectionImpl) Add(item interface{}) { c.items = append(c.items, item) } func (c *CollectionImpl) Get(index int) interface{} { return c.items[index] } type ReadOnlyCollection struct { CollectionImpl } func (ro *ReadOnlyCollection) Add(item interface{}) { // intentionally does nothing } ================================================ FILE: ch02/03_liskov_substitution_principle/05_behaviour_fixed.go ================================================ package lsp type ImmutableCollection interface { Get(index int) interface{} } type MutableCollection interface { ImmutableCollection Add(item interface{}) } type ReadOnlyCollectionV2 struct { items []interface{} } func (ro *ReadOnlyCollectionV2) Get(index int) interface{} { return ro.items[index] } type CollectionImplV2 struct { ReadOnlyCollectionV2 } func (c *CollectionImplV2) Add(item interface{}) { c.items = append(c.items, item) } ================================================ FILE: ch02/04_interface_segregation_principle/01_fat_interface.go ================================================ package isp import ( "context" ) type Item struct { Key string Payload []byte } type FatDbInterface interface { BatchGetItem(IDs ...int) ([]Item, error) BatchGetItemWithContext(ctx context.Context, IDs ...int) ([]Item, error) BatchPutItem(items ...Item) error BatchPutItemWithContext(ctx context.Context, items ...Item) error 
DeleteItem(ID int) error DeleteItemWithContext(ctx context.Context, item Item) error GetItem(ID int) (Item, error) GetItemWithContext(ctx context.Context, ID int) (Item, error) PutItem(item Item) error PutItemWithContext(ctx context.Context, item Item) error Query(query string, args ...interface{}) ([]Item, error) QueryWithContext(ctx context.Context, query string, args ...interface{}) ([]Item, error) UpdateItem(item Item) error UpdateItemWithContext(ctx context.Context, item Item) error } type Cache struct { db FatDbInterface } func (c *Cache) Get(key string) interface{} { // code removed // load from DB _, _ = c.db.GetItem(42) // code removed return nil } func (c *Cache) Set(key string, value interface{}) { // code removed // save to DB _ = c.db.PutItem(Item{}) // code removed } ================================================ FILE: ch02/04_interface_segregation_principle/02_thin_interface.go ================================================ package isp type myDB interface { GetItem(ID int) (Item, error) PutItem(item Item) error } type CacheV2 struct { db myDB } func (c *CacheV2) Get(key string) interface{} { // code removed // load from DB _, _ = c.db.GetItem(42) // code removed return nil } func (c *CacheV2) Set(key string, value interface{}) { // code removed // save from DB _ = c.db.PutItem(Item{}) // code removed } ================================================ FILE: ch02/04_interface_segregation_principle/03_repeated_inputs.go ================================================ package isp import ( "context" "errors" ) func Encrypt(ctx context.Context, data []byte) ([]byte, error) { // As this operation make take too long, we need to be able to kill it stop := ctx.Done() result := make(chan []byte, 1) go func() { defer close(result) // pull the encryption key from context keyRaw := ctx.Value("encryption-key") if keyRaw == nil { panic("encryption key not found in context") } key := keyRaw.([]byte) // perform encryption ciperText := performEncryption(key, data) 
// signal complete by sending the result result <- ciperText }() select { case ciperText := <-result: // happy path return ciperText, nil case <-stop: // cancelled return nil, errors.New("operation cancelled") } } func performEncryption(key []byte, data []byte) []byte { // TODO: implement return nil } ================================================ FILE: ch02/04_interface_segregation_principle/04_repeated_inputs.go ================================================ package isp import ( "errors" ) type Value interface { Value(key interface{}) interface{} } type Monitor interface { Done() <-chan struct{} } func EncryptV2(keyValue Value, monitor Monitor, data []byte) ([]byte, error) { // As this operation make take too long, we need to be able to kill it stop := monitor.Done() result := make(chan []byte, 1) go func() { defer close(result) // pull the encryption key from Value keyRaw := keyValue.Value("encryption-key") if keyRaw == nil { panic("encryption key not found in context") } key := keyRaw.([]byte) // perform encryption ciperText := performEncryption(key, data) // signal complete by sending the result result <- ciperText }() select { case ciperText := <-result: // happy path return ciperText, nil case <-stop: // cancelled return nil, errors.New("operation cancelled") } } ================================================ FILE: ch02/04_interface_segregation_principle/05_repeated_inputs.go ================================================ package isp import ( "context" ) func UseEncryptV2() { // create a context ctx, cancel := context.WithCancel(context.Background()) defer cancel() // store the key ctx = context.WithValue(ctx, "encryption-key", "-secret-") // call the function _, _ = EncryptV2(ctx, ctx, []byte("my data")) } ================================================ FILE: ch02/04_interface_segregation_principle/06_implicit_interfaces.go ================================================ package isp import ( "fmt" ) type Talker interface { SayHello() string } type 
Dog struct{} // The method implicitly implements the Talker interface func (d Dog) SayHello() string { return "Woof!" } func Speak() { var talker Talker talker = Dog{} fmt.Print(talker.SayHello()) } ================================================ FILE: ch03/01_optimizing_for_humans/01_not_so_simple.go ================================================ package humans import ( "bytes" "strconv" "strings" ) func NotSoSimple(ID int64, name string, age int, registered bool) string { out := &bytes.Buffer{} out.WriteString(strconv.FormatInt(ID, 10)) out.WriteString("-") out.WriteString(strings.Replace(name, " ", "_", -1)) out.WriteString("-") out.WriteString(strconv.Itoa(age)) out.WriteString("-") out.WriteString(strconv.FormatBool(registered)) return out.String() } ================================================ FILE: ch03/01_optimizing_for_humans/02_start_simple.go ================================================ package humans import ( "fmt" "strings" ) func Simpler(ID int64, name string, age int, registered bool) string { nameWithNoSpaces := strings.Replace(name, " ", "_", -1) return fmt.Sprintf("%d-%s-%d-%t", ID, nameWithNoSpaces, age, registered) } ================================================ FILE: ch03/01_optimizing_for_humans/03_too_abstract.go ================================================ package humans import ( "io/ioutil" "net/http" ) type myGetter interface { Get(url string) (*http.Response, error) } func TooAbstract(getter myGetter, url string) ([]byte, error) { resp, err := getter.Get(url) if err != nil { return nil, err } defer resp.Body.Close() return ioutil.ReadAll(resp.Body) } ================================================ FILE: ch03/01_optimizing_for_humans/04_common_concept.go ================================================ package humans import ( "io/ioutil" "net/http" ) func CommonConcept(url string) ([]byte, error) { resp, err := http.Get(url) if err != nil { return nil, err } defer resp.Body.Close() return ioutil.ReadAll(resp.Body) } 
================================================
FILE: ch03/01_optimizing_for_humans/05_boolean_param.go
================================================
package humans

import (
	"time"
)

// Pet is a named pet; Dog is true for dogs (see NewDog / NewCat).
type Pet struct {
	Name string
	Dog  bool
	Born time.Time
}

// NewPet creates a Pet whose Born time is the current time.
func NewPet(name string, isDog bool) Pet {
	return Pet{
		Name: name,
		Dog:  isDog,
		Born: time.Now(),
	}
}

// CreatePetsV1 shows the boolean-parameter call style: the meaning of
// `true` is not visible at the call site.
func CreatePetsV1() {
	NewPet("Fido", true)
}
================================================
FILE: ch03/01_optimizing_for_humans/06_hidden_boolean.go
================================================
package humans

// named values for NewPet's isDog parameter
const (
	isDog = true
	isCat = false
)

// NewDog creates a Pet with Dog set to true.
func NewDog(name string) Pet {
	return NewPet(name, isDog)
}

// NewCat creates a Pet with Dog set to false.
func NewCat(name string) Pet {
	return NewPet(name, isCat)
}

// CreatePetsV2 shows the same call as CreatePetsV1 with the boolean hidden.
func CreatePetsV2() {
	NewDog("Fido")
}
================================================
FILE: ch03/01_optimizing_for_humans/07_wide_formatter.go
================================================
package humans

// WideFormatter has one method per output format.
type WideFormatter interface {
	ToCSV(pets []Pet) ([]byte, error)
	ToGOB(pets []Pet) ([]byte, error)
	ToJSON(pets []Pet) ([]byte, error)
}
================================================
FILE: ch03/01_optimizing_for_humans/08_thin_formatters.go
================================================
package humans

// ThinFormatter has a single method; each format is its own implementation.
type ThinFormatter interface {
	Format(pets []Pet) ([]byte, error)
}

// CSVFormatter is a ThinFormatter for CSV (stub: returns nil, nil).
type CSVFormatter struct{}

func (f CSVFormatter) Format(pets []Pet) ([]byte, error) {
	// convert slice of pets to CSV
	return nil, nil
}

// GOBFormatter is a ThinFormatter for GOB (stub: returns nil, nil).
type GOBFormatter struct{}

func (f GOBFormatter) Format(pets []Pet) ([]byte, error) {
	// convert slice of pets to GOB
	return nil, nil
}

// JSONFormatter is a ThinFormatter for JSON (stub: returns nil, nil).
type JSONFormatter struct{}

func (f JSONFormatter) Format(pets []Pet) ([]byte, error) {
	// convert slice of pets to JSON
	return nil, nil
}
================================================
FILE: ch03/01_optimizing_for_humans/09_extra_config.go
================================================
package humans

// PetFetcher searches the data store for pets whose name matches the search string.
// Limit is optional (default is 100). Offset is optional (default 0).
// sortBy is optional (default name). sortAscending is optional
// NOTE(review): this stub applies none of the defaults above; it always
// returns an empty slice.
func PetFetcher(search string, limit int, offset int, sortBy string, sortAscending bool) []Pet {
	return []Pet{}
}

// PetFetcherTypicalUsage passes zero values, presumably to request the
// documented defaults.
func PetFetcherTypicalUsage() {
	_ = PetFetcher("Fido", 0, 0, "", true)
}
================================================
FILE: ch03/02_unit_tests/01_loader.go
================================================
package unit_tests

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"testing"

	"github.com/stretchr/testify/assert"
)

// Loader loads a Pet by ID; a nil Pet with a nil error means "not found".
type Loader interface {
	Load(ID int) (*Pet, error)
}

// TestLoadAndPrint_happyPath checks the output when a pet is found.
func TestLoadAndPrint_happyPath(t *testing.T) {
	result := &bytes.Buffer{}
	LoadAndPrint(&happyPathLoader{}, 1, result)
	assert.Contains(t, result.String(), "Pet named")
}

// TestLoadAndPrint_notFound checks the output when the loader returns no pet.
func TestLoadAndPrint_notFound(t *testing.T) {
	result := &bytes.Buffer{}
	LoadAndPrint(&missingLoader{}, 1, result)
	assert.Contains(t, result.String(), "no such pet")
}

// TestLoadAndPrint_error checks the output when the loader fails.
func TestLoadAndPrint_error(t *testing.T) {
	result := &bytes.Buffer{}
	LoadAndPrint(&errorLoader{}, 1, result)
	assert.Contains(t, result.String(), "failed to load")
}

// LoadAndPrint loads pet ID and writes exactly one message to dest:
// the load error, "no such pet found", or the pet's name.
func LoadAndPrint(loader Loader, ID int, dest io.Writer) {
	loadedPet, err := loader.Load(ID)
	if err != nil {
		fmt.Fprintf(dest, "failed to load pet with ID %d with error: %s", ID, err)
		return
	}

	if loadedPet == nil {
		fmt.Fprintf(dest, "no such pet found")
		return
	}

	fmt.Fprintf(dest, "Pet named %s loaded", loadedPet.Name)
}

// implements Loader: always returns an empty Pet
type happyPathLoader struct {
}

func (l *happyPathLoader) Load(ID int) (*Pet, error) {
	return &Pet{}, nil
}

// implements Loader: always reports "not found" (nil, nil)
type missingLoader struct {
}

func (l *missingLoader) Load(ID int) (*Pet, error) {
	return nil, nil
}

// implements Loader: always fails
type errorLoader struct {
}

func (l *errorLoader) Load(ID int) (*Pet, error) {
	return nil, errors.New("failed")
}
================================================
FILE: ch03/02_unit_tests/02_language_feature.go
================================================
package unit_tests

import
( "testing" "github.com/stretchr/testify/assert" ) type Pet struct { Name string } func NewPet(name string) *Pet { return &Pet{ Name: name, } } func TestLanguageFeatures(t *testing.T) { petFish := NewPet("Goldie") assert.IsType(t, &Pet{}, petFish) } ================================================ FILE: ch03/02_unit_tests/03_simple_test.go ================================================ package unit_tests import ( "testing" "github.com/stretchr/testify/assert" ) func concat(a, b string) string { return a + b } func TestTooSimple(t *testing.T) { a := "Hello " b := "World" expected := "Hello World" assert.Equal(t, expected, concat(a, b)) } ================================================ FILE: ch03/02_unit_tests/04_test_from_api.go ================================================ package unit_tests import ( "database/sql" ) type PetSaver struct{} // save the supplied pet and return the ID func (p PetSaver) Save(pet Pet) (int, error) { err := p.validate(pet) if err != nil { return 0, err } result, err := p.save(pet) if err != nil { return 0, err } return p.extractID(result) } // ensure the pet record is complete func (p PetSaver) validate(pet Pet) error { return nil } // save to the datastore func (p PetSaver) save(pet Pet) (sql.Result, error) { return nil, nil } // extract the ID from the result func (p PetSaver) extractID(result sql.Result) (int, error) { return 0, nil } ================================================ FILE: ch03/02_unit_tests/05_repeated_code.go ================================================ package unit_tests import ( "testing" "github.com/stretchr/testify/assert" ) // Round the supplied number to the nearest integer func Round(in float64) int { return 0 } func TestRound_down(t *testing.T) { in := float64(1.1) expected := 1 result := Round(in) assert.Equal(t, expected, result) } func TestRound_up(t *testing.T) { in := float64(3.7) expected := 4 result := Round(in) assert.Equal(t, expected, result) } func TestRound_noChange(t *testing.T) { in := 
float64(6.0) expected := 6 result := Round(in) assert.Equal(t, expected, result) } ================================================ FILE: ch03/02_unit_tests/06_tdt.go ================================================ package unit_tests import ( "testing" "github.com/stretchr/testify/assert" ) func TestRound(t *testing.T) { scenarios := []struct { desc string in float64 expected int }{ { desc: "round down", in: 1.1, expected: 1, }, { desc: "round up", in: 3.7, expected: 4, }, { desc: "unchanged", in: 6.0, expected: 6, }, } for _, scenario := range scenarios { in := scenario.in result := Round(in) assert.Equal(t, scenario.expected, result) } } ================================================ FILE: ch03/02_unit_tests/07_person_loader.go ================================================ package unit_tests import ( "errors" ) var ErrNotFound = errors.New("person not found") type Person struct { Name string } //go:generate mockery -name PersonLoader -testonly -inpkg -case=underscore type PersonLoader interface { Load(ID int) (*Person, error) } func LoadPersonName(loader PersonLoader, ID int) (string, error) { person, err := loader.Load(ID) if err != nil { return "", err } return person.Name, nil } ================================================ FILE: ch03/02_unit_tests/08_stub.go ================================================ package unit_tests // Stubbed implementation of PersonLoader type PersonLoaderStub struct { Person *Person Error error } func (p *PersonLoaderStub) Load(ID int) (*Person, error) { return p.Person, p.Error } ================================================ FILE: ch03/02_unit_tests/09_stub_tdt.go ================================================ package unit_tests import ( "errors" "testing" "github.com/stretchr/testify/assert" ) func TestLoadPersonNameStubs(t *testing.T) { // this value does not matter as the stub ignores it fakeID := 1 scenarios := []struct { desc string loaderStub *PersonLoaderStub expectedName string expectErr bool }{ { desc: 
"happy path", loaderStub: &PersonLoaderStub{ Person: &Person{Name: "Sophia"}, }, expectedName: "Sophia", expectErr: false, }, { desc: "input error", loaderStub: &PersonLoaderStub{ Error: ErrNotFound, }, expectedName: "", expectErr: true, }, { desc: "system error path", loaderStub: &PersonLoaderStub{ Error: errors.New("something failed"), }, expectedName: "", expectErr: true, }, } for _, scenario := range scenarios { result, resultErr := LoadPersonName(scenario.loaderStub, fakeID) assert.Equal(t, scenario.expectedName, result, scenario.desc) assert.Equal(t, scenario.expectErr, resultErr != nil, scenario.desc) } } ================================================ FILE: ch03/02_unit_tests/10_mocks.go ================================================ package unit_tests import ( "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestLoadPersonName(t *testing.T) { // this value does not matter as the stub ignores it fakeID := 1 scenarios := []struct { desc string configureMock func(stub *PersonLoaderMock) expectedName string expectErr bool }{ { desc: "happy path", configureMock: func(loaderMock *PersonLoaderMock) { loaderMock.On("Load", mock.Anything). Return(&Person{Name: "Sophia"}, nil). Once() }, expectedName: "Sophia", expectErr: false, }, { desc: "input error", configureMock: func(loaderMock *PersonLoaderMock) { loaderMock.On("Load", mock.Anything). Return(nil, ErrNotFound). Once() }, expectedName: "", expectErr: true, }, { desc: "system error path", configureMock: func(loaderMock *PersonLoaderMock) { loaderMock.On("Load", mock.Anything). Return(nil, errors.New("something failed")). 
Once() }, expectedName: "", expectErr: true, }, } for _, scenario := range scenarios { mockLoader := &PersonLoaderMock{} scenario.configureMock(mockLoader) result, resultErr := LoadPersonName(mockLoader, fakeID) assert.Equal(t, scenario.expectedName, result, scenario.desc) assert.Equal(t, scenario.expectErr, resultErr != nil, scenario.desc) assert.True(t, mockLoader.AssertExpectations(t), scenario.desc) } } // Mocked implementation of PersonLoader type PersonLoaderMock struct { mock.Mock } func (p *PersonLoaderMock) Load(ID int) (*Person, error) { outputs := p.Mock.Called(ID) person := outputs.Get(0) err := outputs.Error(1) if person != nil { return person.(*Person), err } return nil, err } ================================================ FILE: ch03/03_test_induced_damage/01_io_closer.go ================================================ package test_damage import ( "io" ) func WriteAndClose(destination io.WriteCloser, contents string) error { defer destination.Close() _, err := destination.Write([]byte(contents)) if err != nil { return err } return nil } ================================================ FILE: ch03/03_test_induced_damage/02_json.go ================================================ package test_damage import ( "encoding/json" "io" ) func PrintAsJSON(destination io.Writer, plant Plant) error { bytes, err := json.Marshal(plant) if err != nil { return err } destination.Write(bytes) return nil } type Plant struct { Name string } ================================================ FILE: ch03/04_visualizing_dependencies/depgraph.sh ================================================ #!/usr/bin/env bash # Note: # This script should be run in the base directory of the project/service # Inputs # # This cuts down on typing by allowing you to enter only the sub-directory you wish to graph; instead of the entire # package prefix="./" PKG=${1#$prefix} # Constants # # Save the file on the desktop (so it's easy to find) DEST_FILE=~/Desktop/depgraph.png # Calculate the 
package in the current directory and assume this is the base or project package BASE_PKG=$(go list) EXCLUSIONS="$BASE_PKG/vendor" BASE_PKG_DELIMITED=$(echo $BASE_PKG | sed 's/\//\\\//g') # Generate godepgraph -s \ -o "$BASE_PKG" \ -p "$EXCLUSIONS" \ $BASE_PKG/${PKG} | sed "s/$BASE_PKG_DELIMITED//g" | dot -Tpng -o $DEST_FILE # Open open $DEST_FILE ================================================ FILE: ch03/fake.go ================================================ package Hands_On_Dependency_Injection_in_Go func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } ================================================ FILE: ch04/01_welcome/01_bad_names.go ================================================ package welcome type HouseV1 struct { a string b int t int p float64 } ================================================ FILE: ch04/01_welcome/02_improved_names.go ================================================ package welcome type HouseV2 struct { address string bedrooms int toilets int price float64 } ================================================ FILE: ch04/01_welcome/03_long_method.go ================================================ package welcome import ( "database/sql" "encoding/json" "net/http" "strconv" ) func longMethod(resp http.ResponseWriter, req *http.Request) { err := req.ParseForm() if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } userID, err := strconv.ParseInt(req.Form.Get("UserID"), 10, 64) if err != nil { resp.WriteHeader(http.StatusPreconditionFailed) return } row := DB.QueryRow("SELECT * FROM people WHERE ID = ?", userID) person := &Person{} err = row.Scan(&person.ID, &person.Name, &person.Phone) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } encoder := json.NewEncoder(resp) err = encoder.Encode(person) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } } var DB *sql.DB type Person struct { ID int64 Name string 
Phone string } ================================================ FILE: ch04/01_welcome/04_long_method_test.go ================================================ package welcome import ( "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLongMethod_happyPath(t *testing.T) { // build request request := &http.Request{} request.PostForm = url.Values{} request.PostForm.Add("UserID", "123") // mock the database var mockDB sqlmock.Sqlmock var err error DB, mockDB, err = sqlmock.New() require.NoError(t, err) mockDB.ExpectQuery("SELECT .* FROM people WHERE ID = ?"). WithArgs(123). WillReturnRows(sqlmock.NewRows([]string{"ID", "Name", "Phone"}).AddRow(123, "May", "0123456789")) // build response response := httptest.NewRecorder() // call method longMethod(response, request) // validate response require.Equal(t, http.StatusOK, response.Code) // validate the JSON responseBytes, err := ioutil.ReadAll(response.Body) require.NoError(t, err) expectedJSON := `{"ID":123,"Name":"May","Phone":"0123456789"}` + "\n" assert.Equal(t, expectedJSON, string(responseBytes)) } ================================================ FILE: ch04/01_welcome/05_short_methods.go ================================================ package welcome import ( "encoding/json" "net/http" "strconv" ) func shortMethods(resp http.ResponseWriter, req *http.Request) { userID, err := extractUserID(req) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } person, err := loadPerson(userID) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } outputPerson(resp, person) } func extractUserID(req *http.Request) (int64, error) { err := req.ParseForm() if err != nil { return 0, err } return strconv.ParseInt(req.Form.Get("UserID"), 10, 64) } func loadPerson(userID int64) (*Person, error) { row := DB.QueryRow("SELECT * FROM people WHERE ID = ?", userID) person := 
&Person{} err := row.Scan(&person.ID, &person.Name, &person.Phone) if err != nil { return nil, err } return person, nil } func outputPerson(resp http.ResponseWriter, person *Person) { encoder := json.NewEncoder(resp) err := encoder.Encode(person) if err != nil { resp.WriteHeader(http.StatusInternalServerError) return } } ================================================ FILE: ch04/03_known_issues/01_data_and_rest/get_example.go ================================================ //+build ignore package data_and_rest import ( "encoding/json" "io" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/01/data" ) // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(person) } ================================================ FILE: ch04/03_known_issues/02_config_coupling/config.go ================================================ package config_coupling import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/03_known_issues/02_config_coupling/currency" ) type Config struct { DefaultCurrency currency.Currency `json:"default_currency"` } ================================================ FILE: ch04/03_known_issues/02_config_coupling/currency/currency.go ================================================ package currency import ( "encoding/json" "fmt" ) // Currency is a custom type; used for convenience and code readability type Currency string // UnmarshalJSON implements json.Unmarshaler func (c *Currency) UnmarshalJSON(in []byte) error { var s string err := json.Unmarshal(in, &s) if err != nil { return err } currency, valid := validCurrencies[s] if !valid { return fmt.Errorf("'%s' is not a valid currency", s) } *c = currency return nil } const ( AUD = Currency("AUD") CNY = Currency("CNY") EUR = Currency("EUR") USD = Currency("USD") ) // a map of valid 
currencies var validCurrencies = map[string]Currency{ string(AUD): AUD, string(CNY): CNY, string(EUR): EUR, string(USD): USD, } ================================================ FILE: ch04/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // App is the application config var App *Config // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string } // Load returns the config loaded from environment func init() { filename, found := os.LookupEnv(DefaultEnvVar) if !found { logging.L.Error("failed to locate file specified by %s", DefaultEnvVar) return } _ = load(filename) } func load(filename string) error { App = &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { logging.L.Error("failed to read config file. err: %s", err) return err } err = json.Unmarshal(bytes, App) if err != nil { logging.L.Error("failed to parse config file. 
err : %s", err) return err } return nil } ================================================ FILE: ch04/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: &Config{}, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, App, scenario.desc) }) } } ================================================ FILE: ch04/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // L is the global instance of the logger var L = &LoggerStdOut{} // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) } // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) 
} ================================================ FILE: ch04/acme/internal/modules/data/data.go ================================================ package data import ( "database/sql" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/logging" // import the MySQL Driver _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) func getDB() (*sql.DB, error) { if db == nil { if config.App == nil { return nil, errors.New("config is not initialized") } var err error db, err = sql.Open("mysql", config.App.DSN) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func Save(in *Person) (int, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // perform DB insert query := "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" result, err := db.Exec(query, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { logging.L.Error("failed to save person into DB. 
err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { logging.L.Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func LoadAll() ([]*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return nil, err } // perform DB select query := "SELECT id, fullname, phone, currency, price FROM person" rows, err := db.Query(query) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { logging.L.Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { logging.L.Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func Load(ID int) (*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return nil, err } // perform DB select query := "SELECT id, fullname, phone, currency, price FROM person WHERE id = ? LIMIT 1" row := db.QueryRow(query, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { logging.L.Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } logging.L.Error("failed to convert query result. 
err: %s", err) return nil, err } return out, nil } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch04/acme/internal/modules/data/data_test.go ================================================ package data import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestData_happyPath(t *testing.T) { in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // save resultID, err := Save(in) require.Nil(t, err) assert.True(t, resultID > 0) // load returned, err := Load(resultID) require.NoError(t, err) in.ID = resultID assert.Equal(t, in, returned) // load all all, err := LoadAll() require.NoError(t, err) assert.True(t, len(all) > 0) } ================================================ FILE: ch04/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "encoding/json" "fmt" "io/ioutil" "math" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct{} // Do will perform the conversion func (c *Converter) Do(basePrice float64, currency string) (float64, error) { // load rate from the external API 
response, err := c.loadRateFromServer(currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, config.App.ExchangeRateBaseURL, config.App.ExchangeRateAPIKey, currency) // perform request response, err := http.Get(url) if err != nil { logging.L.Warn("[exchange] failed to load. err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) logging.L.Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) logging.L.Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { logging.L.Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { logging.L.Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch04/acme/internal/modules/get/get.go ================================================ package get import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := data.Load(ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. 
return nil, errPersonNotFound } return nil, err } return person, err } ================================================ FILE: ch04/acme/internal/modules/get/go_test.go ================================================ package get import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetter_Do(t *testing.T) { // inputs ID := 1 name := "John" // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, name, person.FullName) } ================================================ FILE: ch04/acme/internal/modules/list/list.go ================================================ package list import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // Lister will attempt to load all people in the database. // It can return an error caused by the data layer type Lister struct { } // Do will load the people from the data layer func (l *Lister) Do() ([]*data.Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return people, nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := data.LoadAll() if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. 
return nil, errPeopleNotFound } return nil, err } return people, nil } ================================================ FILE: ch04/acme/internal/modules/list/list_test.go ================================================ package list import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLister_Do(t *testing.T) { // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.NoError(t, err) assert.True(t, len(persons) >= 4) } ================================================ FILE: ch04/acme/internal/modules/register/register.go ================================================ package register import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/exchange" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. 
type Registerer struct { } // Do is API for this struct func (r *Registerer) Do(in *data.Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { logging.L.Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(in, price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *data.Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(currency string) (float64, error) { converter := &exchange.Converter{} price, err := converter.Do(config.App.BasePrice, currency) if err != nil { logging.L.Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return data.Save(person) } ================================================ FILE: ch04/acme/internal/modules/register/register_test.go ================================================ package register import ( "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterer_Do(t *testing.T) { // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233345", Currency: "CNY", } // call method registerer := &Registerer{} ID, err := registerer.Do(in) // validate expectations require.NoError(t, err) assert.True(t, ID > 0) } ================================================ FILE: ch04/acme/internal/rest/common_test.go ================================================ package rest import ( "context" "net" ) func getOpenPort() (string, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return "", err } address := listener.Addr().String() listener.Close() return address, nil } func startServer(ctx context.Context) (string, error) { // get open port address, err := getOpenPort() if err != nil { return "", err } // start a server server := New(address) go server.Listen(ctx.Done()) // wait for server to be ready dialer := &net.Dialer{} for { conn, _ := dialer.DialContext(ctx, "tcp", address) if conn != nil { defer conn.Close() return address, nil } select { case <-ctx.Done(): return "", ctx.Err() default: // try again } } return address, nil } ================================================ FILE: ch04/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" 
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/get" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 ) // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. type GetHandler struct { } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get getter := get.Getter{} person, err := getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars["id"] if !exists { // log and return error err := errors.New("[get] person id missing from request") logging.L.Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) logging.L.Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch04/acme/internal/rest/get_test.go ================================================ package rest import ( "context" "io/ioutil" "net/http" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { // ensure the test always fails by giving it a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Create and start a server // With out current implementation, we cannot test this handler without a full server as we need the mux. 
address, err := startServer(ctx) require.NoError(t, err) // build inputs response, err := http.Get("http://" + address + "/person/1/") // validate outputs require.NoError(t, err) require.Equal(t, http.StatusOK, response.StatusCode) expectedPayload := []byte(`{"id":1,"name":"John","phone":"0123456780","currency":"USD","price":100}` + "\n") payload, _ := ioutil.ReadAll(response.Body) defer response.Body.Close() assert.Equal(t, expectedPayload, payload) } ================================================ FILE: ch04/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/list" ) // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll lister := list.Lister{} people, err := lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*data.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well 
return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch04/acme/internal/rest/list_test.go ================================================ package rest import ( "context" "io/ioutil" "net/http" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { // ensure the test always fails by giving it a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Create and start a server // With out current implementation, we cannot test this handler without a full server as we need the mux. address, err := startServer(ctx) require.NoError(t, err) // build inputs response, err := http.Get("http://" + address + "/person/list") // validate outputs require.NoError(t, err) require.Equal(t, http.StatusOK, response.StatusCode) expectedPayload := []byte(`{"people":[{"id":1,"name":"John","phone":"0123456780"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":4,"name":"Ringo","phone":"0123456783"}`) payload, _ := ioutil.ReadAll(response.Body) defer response.Body.Close() // we have to use contains because other tests add more records assert.Contains(t, string(payload), string(expectedPayload)) } ================================================ FILE: ch04/acme/internal/rest/not_found.go ================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch04/acme/internal/rest/not_found_test.go ================================================ 
package rest

import (
	"context"
	"net/http"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestNotFoundHandler_ServeHTTP(t *testing.T) {
	// ensure the test cannot hang forever by giving it a timeout
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Create and start a server
	// With our current implementation, we cannot test this handler without a full server as we need the mux.
	address, err := startServer(ctx)
	require.NoError(t, err)

	// build inputs
	response, err := http.Get("http://" + address + "/some-bad-address")

	// validate outputs
	require.NoError(t, err)
	// always release the connection held by the response body
	defer response.Body.Close()
	require.Equal(t, http.StatusNotFound, response.StatusCode)
}

================================================
FILE: ch04/acme/internal/rest/register.go
================================================
package rest

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/data"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/modules/register"
)

// RegisterHandler is the HTTP handler for the "Register" endpoint.
// In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400.
// There are some programmer errors possible but hopefully these will be caught in testing.
type RegisterHandler struct { } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // register person id, err := h.register(requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } registerer := ®ister.Registerer{} return registerer.Do(person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch04/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "context" "encoding/json" "io" "net/http" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { // ensure the test always fails by giving it a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Create 
and start a server // With out current implementation, we cannot test this handler without a full server as we need the mux. address, err := startServer(ctx) require.NoError(t, err) // build inputs validRequest := buildValidRequest() response, err := http.Post("http://"+address+"/person/register", "application/json", validRequest) // validate outputs require.NoError(t, err) require.Equal(t, http.StatusCreated, response.StatusCode) defer response.Body.Close() // call should output the location to the new person headerLocation := response.Header.Get("Location") assert.Contains(t, headerLocation, "/person/") } func buildValidRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch04/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/gorilla/mux" ) // New will create and initialize the server func New(address string) *Server { return &Server{ address: address, handlerGet: &GetHandler{}, handlerList: &ListHandler{}, handlerNotFound: notFoundHandler, handlerRegister: &RegisterHandler{}, } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, Addr: s.address, } // listen for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers 
router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch04/acme/main.go ================================================ package main import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch04/acme/internal/rest" ) func main() { // bind stop channel to context ctx := context.Background() // start REST server server := rest.New(config.App.Address) server.Listen(ctx.Done()) } ================================================ FILE: ch04/fake.go ================================================ package ch04 func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } ================================================ FILE: ch05/02_advantages/01_function.go ================================================ package advantages import ( "encoding/json" "io/ioutil" "log" ) func SaveConfig(filename string, cfg *Config) error { // convert to JSON data, err := json.Marshal(cfg) if err != nil { return err } // save file err = ioutil.WriteFile(filename, data, 0666) if err != nil { log.Printf("failed to save file '%s' with err: %s", filename, err) return err } return nil } type Config struct { Host string Port int } ================================================ FILE: ch05/02_advantages/02_monkey_patched.go ================================================ package advantages import ( "encoding/json" "fmt" "io/ioutil" "log" ) func SaveConfigPatched(filename string, cfg *Config) error { // convert to JSON data, err := json.Marshal(cfg) if err != nil { return err } // save file err = writeFile(filename, data, 0666) if 
err != nil { log.Printf("failed to save file '%s' with err: %s", filename, err) return err } return nil } // Custom type that allows us to Monkey Patch var writeFile = ioutil.WriteFile // Usage func SaveConfigPatchedUsage() { cfg := &Config{ // build the config } err := SaveConfigPatched("myfile.json", cfg) if err != nil { fmt.Printf("failed with err: %s", err) } } ================================================ FILE: ch05/02_advantages/03_injected_lambda.go ================================================ package advantages import ( "encoding/json" "fmt" "io/ioutil" "log" "os" ) func SaveConfigInjected(writer fileWriter, filename string, cfg *Config) error { // convert to JSON data, err := json.Marshal(cfg) if err != nil { return err } // save file err = writer(filename, data, 0666) if err != nil { log.Printf("failed to save file '%s' with err: %s", filename, err) return err } return nil } // This custom type is not strictly needed but it does make the function // signature a little cleaner type fileWriter func(filename string, data []byte, perm os.FileMode) error // Usage func SaveConfigInjectedUsage() { cfg := &Config{ // build the config } err := SaveConfigInjected(ioutil.WriteFile, "myfile.json", cfg) if err != nil { fmt.Printf("failed with err: %s", err) } } ================================================ FILE: ch05/02_advantages/04_as_object.go ================================================ package advantages import ( "encoding/json" "fmt" "io/ioutil" "log" "os" ) type ConfigSaver struct { FileWriter func(filename string, data []byte, perm os.FileMode) error } func (c ConfigSaver) Save(filename string, cfg *Config) error { // convert to JSON data, err := json.Marshal(cfg) if err != nil { return err } // save file err = c.FileWriter(filename, data, 0666) if err != nil { log.Printf("failed to save file '%s' with err: %s", filename, err) return err } return nil } // Usage func ConfigSaverUsage() { cfg := &Config{ // build the config } saver := &ConfigSaver{ 
FileWriter: ioutil.WriteFile, } err := saver.Save("myfile.json", cfg) if err != nil { fmt.Printf("failed with err: %s", err) } } ================================================ FILE: ch05/02_advantages/05_math_rand.go ================================================ package advantages // A Rand is a source of random numbers. type Rand struct { src Source // code removed } // Int returns a non-negative pseudo-random int. func (r *Rand) Int() int { // code changed for brevity value := r.src.Int63() return int(value) } /* * Top-level convenience functions */ var globalRand = New(&lockedSource{}) // Int returns a non-negative pseudo-random int from the default Source. func Int() int { return globalRand.Int() } /* * Code below here has been modified so that it compiles but does nothing. * The original code is: https://golang.org/src/math/rand/rand.go */ // New returns a new Rand that uses random values from src // to generate other random values. func New(src Source) *Rand { // code changed for brevity return &Rand{ src: src, } } type lockedSource struct { // code removed } func (l *lockedSource) Int63() int64 { // code removed return 0 } // A Source represents a source of uniformly-distributed // pseudo-random int64 values in the range [0, 1<<63). 
type Source interface { Int63() int64 // code removed } ================================================ FILE: ch05/02_advantages/06_math_rand_test.go ================================================ package advantages import ( "testing" "github.com/stretchr/testify/assert" ) func TestInt(t *testing.T) { // monkey patch defer func(original *Rand) { // restore patch after use globalRand = original }(globalRand) // swap out for a predictable outcome globalRand = New(&stubSource{}) // end monkey patch // call the function result := Int() assert.Equal(t, 234, result) } // this is a stubbed implementation of Source that returns a predictable value type stubSource struct { } func (s *stubSource) Int63() int64 { return 234 } ================================================ FILE: ch05/03_applying/01_simple_sqlmock_test.go ================================================ package applying import ( "database/sql" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // clean up afterwards defer testDb.Close() // define the query we are expecting as regular expression queryRegex := `\QINSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)\E` // configure the mock db dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // inputs person := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := SavePerson(testDb, person) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func SavePerson(db *sql.DB, in *Person) (int, error) { // perform DB insert query := "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" result, err := db.Exec(query, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { 
return 0, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { return 0, err } return int(id), nil } ================================================ FILE: ch05/03_applying/02_load.go ================================================ package applying import ( "database/sql" "strings" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const ( // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? LIMIT 1" ) func TestLoad_happyPath(t *testing.T) { expectedResult := &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, } // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlLoadByID) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(2, "Paul", "0123456789", "CAD", 23.45)) // monkey patching the database defer func(original sql.DB) { // restore original DB (after test) db = &original }(*db) db = testDb // end of monkey patch // call function result, err := Load(2) // validate results assert.Equal(t, expectedResult, result) assert.NoError(t, err) assert.NoError(t, dbMock.ExpectationsWereMet()) } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. 
func Load(ID int) (*Person, error) { // code removed/faked for brevity return &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, nil } // code removed for brevity var db = &sql.DB{} // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } ================================================ FILE: ch05/04_disadvantages/01_verbose.go ================================================ package disadvantages import ( "encoding/json" "io/ioutil" "log" ) func SaveConfig(filename string, cfg *Config) error { // convert to JSON data, err := json.Marshal(cfg) if err != nil { return err } // save file err = writeFile(filename, data, 0666) if err != nil { log.Printf("failed to save file '%s' with err: %s", filename, err) return err } return nil } // Custom type that allows var writeFile = ioutil.WriteFile type Config struct { Host string Port int } ================================================ FILE: ch05/04_disadvantages/02_verbose_test.go ================================================ package disadvantages import ( "os" "testing" "github.com/stretchr/testify/assert" ) func TestSaveConfig(t *testing.T) { // inputs filename := "my-config.json" cfg := &Config{ Host: "localhost", Port: 1234, } // monkey patch the file writer defer func(original func(filename string, data []byte, perm os.FileMode) error) { // restore the original writeFile = original }(writeFile) writeFile = func(filename string, data []byte, perm os.FileMode) error { // output error return nil } // call the function err := SaveConfig(filename, cfg) // validate the result assert.NoError(t, err) } ================================================ FILE: 
ch05/04_disadvantages/03_refactored_test.go ================================================ package disadvantages import ( "os" "testing" "github.com/stretchr/testify/assert" ) func TestSaveConfig_refactored(t *testing.T) { // inputs filename := "my-config.json" cfg := &Config{ Host: "localhost", Port: 1234, } // monkey patch the file writer defer restoreWriteFile(writeFile) writeFile = mockWriteFile(nil) // call the function err := SaveConfig(filename, cfg) // validate the result assert.NoError(t, err) } func mockWriteFile(result error) func(filename string, data []byte, perm os.FileMode) error { return func(filename string, data []byte, perm os.FileMode) error { return result } } // remove the restore function to reduce from 3 lines to 1 func restoreWriteFile(original func(filename string, data []byte, perm os.FileMode) error) { // restore the original writeFile = original } ================================================ FILE: ch05/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // App is the application config var App *Config // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string } // Load returns the config loaded from environment func init() { filename, found := 
os.LookupEnv(DefaultEnvVar) if !found { logging.L.Error("failed to locate file specified by %s", DefaultEnvVar) return } _ = load(filename) } func load(filename string) error { App = &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { logging.L.Error("failed to read config file. err: %s", err) return err } err = json.Unmarshal(bytes, App) if err != nil { logging.L.Error("failed to parse config file. err : %s", err) return err } return nil } ================================================ FILE: ch05/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: &Config{}, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, App, scenario.desc) }) } } ================================================ FILE: ch05/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // L is the global instance of the logger var L = &LoggerStdOut{} // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) 
} // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) } // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) } ================================================ FILE: ch05/acme/internal/modules/data/data.go ================================================ package data import ( "database/sql" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/logging" // import the MySQL Driver _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) var getDB = func() (*sql.DB, error) { if db == nil { if config.App == nil { return nil, errors.New("config is not initialized") } var err error db, err = sql.Open("mysql", config.App.DSN) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func Save(in *Person) (int, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // perform DB insert result, err := db.Exec(sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { logging.L.Error("failed to save person into DB. err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { logging.L.Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func LoadAll() ([]*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. 
err: %s", err) return nil, err } // perform DB select rows, err := db.Query(sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { logging.L.Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { logging.L.Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func Load(ID int) (*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return nil, err } // perform DB select row := db.QueryRow(sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { logging.L.Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } logging.L.Error("failed to convert query result. 
err: %s", err) return nil, err } return out, nil } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } func init() { // ensure the config is loaded and the db initialized _, _ = getDB() } ================================================ FILE: ch05/acme/internal/modules/data/data_test.go ================================================ package data import ( "database/sql" "errors" "strings" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here defer func(original sql.DB) { // restore original DB (after test) db = &original }(*db) // replace db for this test db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here defer func(original sql.DB) { // restore original DB (after test) db = &original }(*db) // replace db for this test db = testDb // end 
of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // monkey patching starts here defer func(original func() (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func() (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function results, err := LoadAll() // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function result, err := Load(2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } ================================================ FILE: ch05/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "encoding/json" "fmt" "io/ioutil" "math" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct{} // Do will perform the conversion func (c *Converter) Do(basePrice float64, currency string) (float64, error) { // load rate 
from the external API response, err := c.loadRateFromServer(currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, config.App.ExchangeRateBaseURL, config.App.ExchangeRateAPIKey, currency) // perform request response, err := http.Get(url) if err != nil { logging.L.Warn("[exchange] failed to load. err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) logging.L.Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) logging.L.Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { logging.L.Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { logging.L.Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch05/acme/internal/modules/get/get.go ================================================ package get import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := loader(ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. 
return nil, errPersonNotFound } return nil, err } return person, err } // this function as a variable allows us to Monkey Patch during testing var loader = data.Load ================================================ FILE: ch05/acme/internal/modules/get/go_test.go ================================================ package get import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(ID int) (*data.Person, error) { result := &data.Person{ ID: 1234, FullName: "Doug", } var resultErr error return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // monkey patch calls to the data package defer func(original func(ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(ID int) (*data.Person, error) { var result *data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(ID int) (*data.Person, error) { var result *data.Person resultErr := errors.New("failed to load 
person") return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Error(t, err) assert.Nil(t, person) } ================================================ FILE: ch05/acme/internal/modules/list/list.go ================================================ package list import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // Lister will attempt to load all people in the database. // It can return an error caused by the data layer type Lister struct { } // Do will load the people from the data layer func (l *Lister) Do() ([]*data.Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return people, nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := loader() if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. 
return nil, errPeopleNotFound } return nil, err } return people, nil } // this function as a variable allows us to Monkey Patch during testing var loader = data.LoadAll ================================================ FILE: ch05/acme/internal/modules/list/list_test.go ================================================ package list import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // monkey patch calls to the data package defer func(original func() ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func() ([]*data.Person, error) { result := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } var resultErr error return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) } func TestLister_Do_noResults(t *testing.T) { // monkey patch calls to the data package defer func(original func() ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func() ([]*data.Person, error) { var result []*data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) } func TestLister_Do_error(t *testing.T) { // monkey patch calls to the data package defer func(original func() ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func() ([]*data.Person, error) { var result []*data.Person resultErr := errors.New("failed to load people") return result, resultErr } // end of monkey patch // call 
method lister := &Lister{} persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) } ================================================ FILE: ch05/acme/internal/modules/register/register.go ================================================ package register import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/exchange" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. 
type Registerer struct { } // Do is API for this struct func (r *Registerer) Do(in *data.Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { logging.L.Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(in, price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *data.Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(currency string) (float64, error) { converter := &exchange.Converter{} price, err := converter.Do(config.App.BasePrice, currency) if err != nil { logging.L.Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return saver(person) } // this function as a variable allows us to Monkey Patch during testing var saver = data.Save ================================================ FILE: ch05/acme/internal/modules/register/register_test.go ================================================ package register import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // monkey patch calls to the data package defer func(original func(in *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(in *data.Person) (int, error) { result := 888 var resultErr error return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{} ID, err := registerer.Do(in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) } func TestRegisterer_Do_error(t *testing.T) { // monkey patch calls to the data package defer func(original func(in *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(in *data.Person) (int, error) { var result int resultErr := errors.New("failed to save") return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{} ID, err := registerer.Do(in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) } ================================================ FILE: 
ch05/acme/internal/rest/common_test.go ================================================ package rest import ( "context" "net" ) func getOpenPort() (string, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return "", err } address := listener.Addr().String() listener.Close() return address, nil } func startServer(ctx context.Context) (string, error) { // get open port address, err := getOpenPort() if err != nil { return "", err } // start a server server := New(address) go server.Listen(ctx.Done()) // wait for server to be ready dialer := &net.Dialer{} for { conn, _ := dialer.DialContext(ctx, "tcp", address) if conn != nil { defer conn.Close() return address, nil } select { case <-ctx.Done(): return "", ctx.Err() default: // try again } } return address, nil } ================================================ FILE: ch05/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/get" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 ) // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. 
type GetHandler struct { } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get getter := get.Getter{} person, err := getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars["id"] if !exists { // log and return error err := errors.New("[get] person id missing from request") logging.L.Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) logging.L.Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch05/acme/internal/rest/get_test.go ================================================ package rest import ( "context" "io/ioutil" "net/http" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { // ensure the test always fails by giving it a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Create and start a server // With out current implementation, we cannot test this handler without a full server as we need the mux. 
address, err := startServer(ctx) require.NoError(t, err) // build inputs response, err := http.Get("http://" + address + "/person/1/") // validate outputs require.NoError(t, err) require.Equal(t, http.StatusOK, response.StatusCode) expectedPayload := []byte(`{"id":1,"name":"John","phone":"0123456780","currency":"USD","price":100}` + "\n") payload, _ := ioutil.ReadAll(response.Body) defer response.Body.Close() assert.Equal(t, expectedPayload, payload) } ================================================ FILE: ch05/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/list" ) // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll lister := list.Lister{} people, err := lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*data.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well 
return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch05/acme/internal/rest/list_test.go ================================================ package rest import ( "context" "io/ioutil" "net/http" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { // ensure the test always fails by giving it a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Create and start a server // With out current implementation, we cannot test this handler without a full server as we need the mux. address, err := startServer(ctx) require.NoError(t, err) // build inputs response, err := http.Get("http://" + address + "/person/list") // validate outputs require.NoError(t, err) require.Equal(t, http.StatusOK, response.StatusCode) expectedPayload := []byte(`{"people":[{"id":1,"name":"John","phone":"0123456780"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":4,"name":"Ringo","phone":"0123456783"}`) payload, _ := ioutil.ReadAll(response.Body) defer response.Body.Close() // we have to use contains because other tests add more records assert.Contains(t, string(payload), string(expectedPayload)) } ================================================ FILE: ch05/acme/internal/rest/not_found.go ================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch05/acme/internal/rest/not_found_test.go ================================================ 
package rest import ( "context" "net/http" "testing" "time" "github.com/stretchr/testify/require" ) func TestNotFoundHandler_ServeHTTP(t *testing.T) { // ensure the test always fails by giving it a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Create and start a server // With out current implementation, we cannot test this handler without a full server as we need the mux. address, err := startServer(ctx) require.NoError(t, err) // build inputs response, err := http.Get("http://" + address + "/some-bad-address") // validate outputs require.NoError(t, err) require.Equal(t, http.StatusNotFound, response.StatusCode) } ================================================ FILE: ch05/acme/internal/rest/register.go ================================================ package rest import ( "encoding/json" "fmt" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/modules/register" ) // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. 
type RegisterHandler struct { } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // register person id, err := h.register(requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } registerer := ®ister.Registerer{} return registerer.Do(person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch05/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "context" "encoding/json" "io" "net/http" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { // ensure the test always fails by giving it a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Create 
and start a server // With out current implementation, we cannot test this handler without a full server as we need the mux. address, err := startServer(ctx) require.NoError(t, err) // build inputs validRequest := buildValidRequest() response, err := http.Post("http://"+address+"/person/register", "application/json", validRequest) // validate outputs require.NoError(t, err) require.Equal(t, http.StatusCreated, response.StatusCode) defer response.Body.Close() // call should output the location to the new person headerLocation := response.Header.Get("Location") assert.Contains(t, headerLocation, "/person/") } func buildValidRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch05/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/gorilla/mux" ) // New will create and initialize the server func New(address string) *Server { return &Server{ address: address, handlerGet: &GetHandler{}, handlerList: &ListHandler{}, handlerNotFound: notFoundHandler, handlerRegister: &RegisterHandler{}, } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, Addr: s.address, } // listen for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers 
router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch05/acme/main.go ================================================ package main import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch05/acme/internal/rest" ) func main() { // bind stop channel to context ctx := context.Background() // start REST server server := rest.New(config.App.Address) server.Listen(ctx.Done()) } ================================================ FILE: ch05/fake.go ================================================ package ch05 func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } ================================================ FILE: ch06/01_constructor_injection/01_welcome_email.go ================================================ package constructor_injection import ( "errors" ) func NewWelcomeSender(in *Mailer) (*WelcomeSender, error) { // guard clause if in == nil { return nil, errors.New("programmer error: mailer must not provided") } return &WelcomeSender{ mailer: in, }, nil } func NewWelcomeSenderNoGuard(in *Mailer) *WelcomeSender { return &WelcomeSender{ mailer: in, } } // WelcomeSender sends a Welcome email to new users type WelcomeSender struct { mailer *Mailer } func (w *WelcomeSender) Send(to string) error { body := w.buildMessage() return w.mailer.Send(to, body) } // build and return the message body func (w *WelcomeSender) buildMessage() string { return "" } // Mailer sends and receives emails type Mailer struct { Host string Port string Username string Password string } func (m *Mailer) Send(to string, 
body string) error { // send email return nil } func (m *Mailer) Receive(address string) (string, error) { // receive email return "", nil } ================================================ FILE: ch06/01_constructor_injection/01_welcome_email_test.go ================================================ package constructor_injection import ( "testing" "github.com/stretchr/testify/assert" ) func TestNewWelcomeSender_happyPath(t *testing.T) { sender, err := NewWelcomeSender(&Mailer{}) assert.NotNil(t, sender) assert.NoError(t, err) } func TestNewWelcomeSender_guardClause(t *testing.T) { sender, err := NewWelcomeSender(nil) assert.Nil(t, sender) assert.Error(t, err) } func TestNewWelcomeSenderNoGuard_happyPath(t *testing.T) { sender := NewWelcomeSenderNoGuard(&Mailer{}) assert.NotNil(t, sender) } ================================================ FILE: ch06/01_constructor_injection/02_mailer_interface.go ================================================ package constructor_injection // Mailer sends and receives emails type MailerInterface interface { Send(to string, body string) error Receive(address string) (string, error) } ================================================ FILE: ch06/01_constructor_injection/03_sender_interface.go ================================================ package constructor_injection type Sender interface { Send(to string, body string) error } func NewWelcomeSenderV2(in Sender) *WelcomeSenderV2 { return &WelcomeSenderV2{ sender: in, } } // WelcomeSenderV2 sends a Welcome email to new users type WelcomeSenderV2 struct { sender Sender } func (w *WelcomeSenderV2) Send(to string) error { body := w.buildMessage() return w.sender.Send(to, body) } // build and return the message body func (w *WelcomeSenderV2) buildMessage() string { return "" } ================================================ FILE: ch06/01_constructor_injection/05_duck_typing.go ================================================ package constructor_injection import ( "fmt" ) type Talker 
interface { Speak() string Shout() string } type Dog struct{} func (d Dog) Speak() string { return "Woof!" } func (d Dog) Shout() string { return "WOOF!" } func SpeakExample() { var talker Talker talker = Dog{} fmt.Print(talker.Speak()) } ================================================ FILE: ch06/02_advantages/01_easy_to_implement.go ================================================ package advantages // WelcomeSender sends a Welcome email to new users type WelcomeSender struct { Mailer *Mailer } func (w *WelcomeSender) Send(to string) error { body := w.buildMessage() return w.Mailer.Send(to, body) } // build and return the message body func (w *WelcomeSender) buildMessage() string { return "" } // Mailer will send an email type Mailer struct{} func (m *Mailer) Send(to string, body string) error { // send email return nil } ================================================ FILE: ch06/02_advantages/01_easy_to_implement_example_test.go ================================================ package advantages_test import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/02_advantages" ) func ExampleWelcomeSender_Send() { welcomeSender := &advantages.WelcomeSender{ Mailer: &advantages.Mailer{}, } welcomeSender.Send("me@home.com") } ================================================ FILE: ch06/02_advantages/02_easy_to_implement.go ================================================ package advantages func NewWelcomeSenderV2(mailer *Mailer) *WelcomeSenderV2 { return &WelcomeSenderV2{ mailer: mailer, } } // WelcomeSenderV2 sends a Welcome email to new users type WelcomeSenderV2 struct { mailer *Mailer } func (w *WelcomeSenderV2) Send(to string) error { body := w.buildMessage() return w.mailer.Send(to, body) } // build and return the message body func (w *WelcomeSenderV2) buildMessage() string { return "" } ================================================ FILE: ch06/02_advantages/02_easy_to_implement_example_test.go ================================================ 
package advantages_test import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/02_advantages" ) func ExampleWelcomeSenderV2_Send() { welcomeSender := advantages.NewWelcomeSenderV2(&advantages.Mailer{}) welcomeSender.Send("me@home.com") } ================================================ FILE: ch06/02_advantages/03_predictable.go ================================================ package advantages import ( "errors" ) type Engine interface { Start() IncreasePower() DecreasePower() Stop() IsRunning() bool } type Car struct { Engine Engine } func (c *Car) Drive() error { if c.Engine == nil { return errors.New("engine ie missing") } // use the engine c.Engine.Start() c.Engine.IncreasePower() return nil } func (c *Car) Stop() error { if c.Engine == nil { return errors.New("engine ie missing") } // use the engine c.Engine.DecreasePower() c.Engine.Stop() return nil } ================================================ FILE: ch06/02_advantages/04_predictable.go ================================================ package advantages import ( "errors" ) func NewCarV2(engine Engine) (*CarV2, error) { if engine == nil { return nil, errors.New("invalid engine supplied") } return &CarV2{ engine: engine, }, nil } type CarV2 struct { engine Engine } func (c *CarV2) Drive() error { // use the engine c.engine.Start() c.engine.IncreasePower() return nil } func (c *CarV2) Stop() error { // use the engine c.engine.DecreasePower() c.engine.Stop() return nil } ================================================ FILE: ch06/02_advantages/05_encapsulation.go ================================================ package advantages import ( "errors" ) func (c *CarV2) FillPetrolTank() error { // use the engine if c.engine.IsRunning() { return errors.New("cannot fill the tank while the engine is running") } // fill the tank! 
return c.fill() } func (c CarV2) fill() error { // TODO: implement return nil } ================================================ FILE: ch06/02_advantages/06_encapsulation.go ================================================ package advantages import ( "errors" ) func (c *CarV2) FillPetrolTankV2(engine Engine) error { // use the engine if engine.IsRunning() { return errors.New("cannot fill the tank while the engine is running") } // fill the tank! return c.fill() } ================================================ FILE: ch06/03_applying/01/01_register_handler_before.go ================================================ package rest import ( "encoding/json" "fmt" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/01/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/01/register" ) // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. 
type RegisterHandler struct { } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // register person id, err := h.register(requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } registerer := ®ister.Registerer{} return registerer.Do(person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch06/03_applying/01/data/person.go ================================================ package data // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this 
person Price float64 } ================================================ FILE: ch06/03_applying/01/register/register.go ================================================ package register import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/01/data" ) // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. type Registerer struct { } // Do is API for this struct func (r *Registerer) Do(in *data.Person) (int, error) { // fake implementation return 0, nil } ================================================ FILE: ch06/03_applying/02/01_register_handler.go ================================================ package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/02/register" ) // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. 
type RegisterHandler struct { registerer *register.Registerer } ================================================ FILE: ch06/03_applying/02/data/person.go ================================================ package data // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } ================================================ FILE: ch06/03_applying/02/register/register.go ================================================ package register import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/01/data" ) // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. 
type Registerer struct { } // Do is API for this struct func (r *Registerer) Do(in *data.Person) (int, error) { // fake implementation return 0, nil } ================================================ FILE: ch06/03_applying/03/data/person.go ================================================ package data // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } ================================================ FILE: ch06/03_applying/03/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/03/data" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: in func (_m *MockRegisterModel) Do(in *data.Person) (int, error) { ret := _m.Called(in) var r0 int if rf, ok := ret.Get(0).(func(*data.Person) int); ok { r0 = rf(in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(*data.Person) error); ok { r1 = rf(in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch06/03_applying/03/register_test.go ================================================ package rest import ( "net/http" "testing" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ // scenarios go here } for _, s := range scenarios { scenario := s 
t.Run(scenario.desc, func(t *testing.T) { // test goes here }) } } ================================================ FILE: ch06/03_applying/04/data/person.go ================================================ package data // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } ================================================ FILE: ch06/03_applying/04/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/04/data" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: in func (_m *MockRegisterModel) Do(in *data.Person) (int, error) { ret := _m.Called(in) var r0 int if rf, ok := ret.Get(0).(func(*data.Person) int); ok { r0 = rf(in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(*data.Person) error); ok { r1 = rf(in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch06/03_applying/04/register.go ================================================ package rest import ( "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/04/data" ) // RegisterModel will validate and save a registration type RegisterModel interface { Do(in *data.Person) (int, error) } // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning 
"bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // implementation goes here } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch06/03_applying/04/register_test.go ================================================ package rest import ( "bytes" "encoding/json" "io" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ // scenarios go here } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockRegisterModel := scenario.inModelMock() // build handler handler := &RegisterHandler{ registerer: mockRegisterModel, } // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } 
================================================ FILE: ch06/03_applying/05/data/person.go ================================================ package data // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } ================================================ FILE: ch06/03_applying/05/fakes.go ================================================ package reset import ( "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/05/data" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } // Fake/Stub implementations to make the compiler happy type Server struct { address string handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } func NewGetHandler(_ GetModel) *GetHandler { return &GetHandler{} } type GetModel interface { Do(ID int) (*data.Person, error) } type GetHandler struct{} func (g *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {} func NewListHandler(_ ListModel) *ListHandler { return &ListHandler{} } type ListModel interface { Do() ([]*data.Person, error) } type ListHandler struct { } func (l *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {} func NewRegisterHandler(_ RegisterModel) *RegisterHandler { return &RegisterHandler{} } type RegisterModel interface { Do(in *data.Person) (int, error) } type RegisterHandler struct { } func (r *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {} ================================================ FILE: 
ch06/03_applying/05/get/getter.go ================================================ package get import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/05/data" ) // Stub implementation so that the example compiles type Getter struct{} func (g *Getter) Do(ID int) (*data.Person, error) { return nil, nil } ================================================ FILE: ch06/03_applying/05/list/lister.go ================================================ package list import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/05/data" ) // Stub implementation so that the example compiles type Lister struct{} func (l *Lister) Do() ([]*data.Person, error) { return nil, nil } ================================================ FILE: ch06/03_applying/05/register/registerer.go ================================================ package register import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/05/data" ) // Stub implementation so that the example compiles type Registerer struct{} func (r *Registerer) Do(in *data.Person) (int, error) { return 0, nil } ================================================ FILE: ch06/03_applying/05/server.go ================================================ package reset import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/05/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/05/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/03_applying/05/register" ) // New will create and initialize the server func New(address string) *Server { return &Server{ address: address, handlerGet: NewGetHandler(&get.Getter{}), handlerList: NewListHandler(&list.Lister{}), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(®ister.Registerer{}), } } ================================================ FILE: ch06/04_disadvantages/01_lots_of_changes.go 
================================================ package disadvantages // Dealer will shuffle a deck of cards and deal them to the players func DealCards() (player1 []Card, player2 []Card) { // create a new deck of cards cards := newDeck() // shuffle the cards shuffler := &myShuffler{} shuffler.Shuffle(cards) // deal player1 = append(player1, cards[0]) player2 = append(player2, cards[1]) player1 = append(player1, cards[2]) player2 = append(player2, cards[3]) return } // returns a new deck of cards func newDeck() []Card { return []Card{ // code removed } } // Shuffler will shuffle (randomize) the supplied cards type Shuffler interface { Shuffle(cards []Card) } // Card is single Playing Card type Card struct { Suit string Value string } // implements Shuffler type myShuffler struct{} // Shuffle implements shuffler func (m *myShuffler) Shuffle(cards []Card) { // randomize the cards } ================================================ FILE: ch06/04_disadvantages/02_overuse.go ================================================ package disadvantages import ( "encoding/json" "io/ioutil" "net/http" ) const downstreamServer = "http://www.example.com" // FetchRates rates from downstream service type FetchRates struct{} func (f *FetchRates) Fetch() ([]Rate, error) { // build the URL from which to fetch the rates url := downstreamServer + "/rates" // build request request, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } // fetch rates response, err := http.DefaultClient.Do(request) if err != nil { return nil, err } defer response.Body.Close() // read the content of the response data, err := ioutil.ReadAll(response.Body) if err != nil { return nil, err } // convert JSON bytes to Go structs out := &downstreamResponse{} err = json.Unmarshal(data, out) if err != nil { return nil, err } return out.Rates, nil } // response format from the downstream service type downstreamResponse struct { Rates []Rate `json:"rates"` } type Rate struct { Code string Value 
float64 } ================================================ FILE: ch06/04_disadvantages/03_non_obvious.go ================================================ package disadvantages import ( "errors" ) // NewClient creates and initialises the client func NewClient(service DepService) Client { return &clientImpl{ service: service, } } // Client is the exported API type Client interface { DoSomethingUseful() (bool, error) } // implement Client type clientImpl struct { service DepService } func (c *clientImpl) DoSomethingUseful() (bool, error) { // this function does something useful return false, errors.New("not implemented") } type DepService interface { DoSomethingElse() } ================================================ FILE: ch06/04_disadvantages/04_non_obvious_example_test.go ================================================ package disadvantages_test func Example() { } // StubClient is a stub implementation of disadvantages.Client interface type StubClient struct{} // DoSomethingUseful implements disadvantages.Client func (s *StubClient) DoSomethingUseful() (bool, error) { return true, nil } ================================================ FILE: ch06/04_disadvantages/05_constructors.go ================================================ package disadvantages type InnerService struct { innerDep Dependency } func NewInnerService(innerDep Dependency) *InnerService { return &InnerService{ innerDep: innerDep, } } type OuterService struct { // composition innerService *InnerService outerDep Dependency } func NewOuterService(outerDep Dependency, innerDep Dependency) *OuterService { return &OuterService{ innerService: NewInnerService(innerDep), outerDep: outerDep, } } // fake type to satisfy the compiler type Dependency interface { } ================================================ FILE: ch06/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "io/ioutil" "os" 
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // App is the application config var App *Config // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string } // Load returns the config loaded from environment func init() { filename, found := os.LookupEnv(DefaultEnvVar) if !found { logging.L.Error("failed to locate file specified by %s", DefaultEnvVar) return } _ = load(filename) } func load(filename string) error { App = &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { logging.L.Error("failed to read config file. err: %s", err) return err } err = json.Unmarshal(bytes, App) if err != nil { logging.L.Error("failed to parse config file. 
err : %s", err) return err } return nil } ================================================ FILE: ch06/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: &Config{}, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, App, scenario.desc) }) } } ================================================ FILE: ch06/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // L is the global instance of the logger var L = &LoggerStdOut{} // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) } // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) 
} ================================================ FILE: ch06/acme/internal/modules/data/data.go ================================================ package data import ( "database/sql" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/logging" // import the MySQL Driver _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) var getDB = func() (*sql.DB, error) { if db == nil { if config.App == nil { return nil, errors.New("config is not initialized") } var err error db, err = sql.Open("mysql", config.App.DSN) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func Save(in *Person) (int, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. 
err: %s", err) return defaultPersonID, err } // perform DB insert result, err := db.Exec(sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { logging.L.Error("failed to save person into DB. err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { logging.L.Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func LoadAll() ([]*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return nil, err } // perform DB select rows, err := db.Query(sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { logging.L.Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { logging.L.Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func Load(ID int) (*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return nil, err } // perform DB select row := db.QueryRow(sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { logging.L.Warn("failed to load requested person '%d'. 
err: %s", ID, err) return nil, ErrNotFound } logging.L.Error("failed to convert query result. err: %s", err) return nil, err } return out, nil } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } func init() { // ensure the config is loaded and the db initialized _, _ = getDB() } ================================================ FILE: ch06/acme/internal/modules/data/data_test.go ================================================ package data import ( "database/sql" "errors" "strings" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here defer func(original sql.DB) { // restore original DB (after test) db = &original }(*db) // replace db for this test db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here defer func(original sql.DB) { // restore 
original DB (after test) db = &original }(*db) // replace db for this test db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // monkey patching starts here defer func(original func() (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func() (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function results, err := LoadAll() // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function result, err := Load(2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } ================================================ FILE: ch06/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "encoding/json" "fmt" "io/ioutil" "math" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct{} // Do will perform the conversion func (c *Converter) Do(basePrice float64, currency string) (float64, error) { // load rate 
from the external API response, err := c.loadRateFromServer(currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, config.App.ExchangeRateBaseURL, config.App.ExchangeRateAPIKey, currency) // perform request response, err := http.Get(url) if err != nil { logging.L.Warn("[exchange] failed to load. err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) logging.L.Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) logging.L.Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { logging.L.Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { logging.L.Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch06/acme/internal/modules/get/get.go ================================================ package get import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := loader(ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. 
return nil, errPersonNotFound } return nil, err } return person, err } // this function as a variable allows us to Monkey Patch during testing var loader = data.Load ================================================ FILE: ch06/acme/internal/modules/get/go_test.go ================================================ package get import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(ID int) (*data.Person, error) { result := &data.Person{ ID: 1234, FullName: "Doug", } var resultErr error return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // monkey patch calls to the data package defer func(original func(ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(ID int) (*data.Person, error) { var result *data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(ID int) (*data.Person, error) { var result *data.Person resultErr := errors.New("failed to load 
person") return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Error(t, err) assert.Nil(t, person) } ================================================ FILE: ch06/acme/internal/modules/list/list.go ================================================ package list import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // Lister will attempt to load all people in the database. // It can return an error caused by the data layer type Lister struct { } // Do will load the people from the data layer func (l *Lister) Do() ([]*data.Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return people, nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := loader() if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. 
return nil, errPeopleNotFound } return nil, err } return people, nil } // this function as a variable allows us to Monkey Patch during testing var loader = data.LoadAll ================================================ FILE: ch06/acme/internal/modules/list/list_test.go ================================================ package list import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // monkey patch calls to the data package defer func(original func() ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func() ([]*data.Person, error) { result := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } var resultErr error return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) } func TestLister_Do_noResults(t *testing.T) { // monkey patch calls to the data package defer func(original func() ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func() ([]*data.Person, error) { var result []*data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) } func TestLister_Do_error(t *testing.T) { // monkey patch calls to the data package defer func(original func() ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func() ([]*data.Person, error) { var result []*data.Person resultErr := errors.New("failed to load people") return result, resultErr } // end of monkey patch // call 
method lister := &Lister{} persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) } ================================================ FILE: ch06/acme/internal/modules/register/register.go ================================================ package register import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/exchange" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. 
type Registerer struct { } // Do is API for this struct func (r *Registerer) Do(in *data.Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { logging.L.Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(in, price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *data.Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(currency string) (float64, error) { converter := &exchange.Converter{} price, err := converter.Do(config.App.BasePrice, currency) if err != nil { logging.L.Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return saver(person) } // this function as a variable allows us to Monkey Patch during testing var saver = data.Save ================================================ FILE: ch06/acme/internal/modules/register/register_test.go ================================================ package register import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // monkey patch calls to the data package defer func(original func(in *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(in *data.Person) (int, error) { result := 888 var resultErr error return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{} ID, err := registerer.Do(in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) } func TestRegisterer_Do_error(t *testing.T) { // monkey patch calls to the data package defer func(original func(in *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(in *data.Person) (int, error) { var result int resultErr := errors.New("failed to save") return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{} ID, err := registerer.Do(in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) } ================================================ FILE: 
ch06/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration //go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated type GetModel interface { Do(ID int) (*data.Person, error) } // NewGetHandler is the constructor for GetHandler func NewGetHandler(model GetModel) *GetHandler { return &GetHandler{ getter: model, } } // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. 
type GetHandler struct { getter GetModel } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") logging.L.Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) logging.L.Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch06/acme/internal/rest/get_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockGetModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { output := &data.Person{ ID: 1, FullName: "John", Phone: "0123456789", Currency: "USD", Price: 100, } mockGetModel := &MockGetModel{} mockGetModel.On("Do", mock.Anything).Return(output, nil).Once() return mockGetModel }, expectedStatus: http.StatusOK, expectedPayload: `{"id":1,"name":"John","phone":"0123456789","currency":"USD","price":100}` + "\n", }, { desc: "bad input (ID is 
invalid)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/x/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "x"}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "bad input (ID is missing)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person//", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "dependency fail", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "requested registration does not exist", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("person not found")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t 
*testing.T) { // define model layer mock mockGetModel := scenario.inModelMock() // build handler handler := NewGetHandler(mockGetModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch06/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" ) // ListModel will load all registrations //go:generate mockery -name=ListModel -case underscore -testonly -inpkg -note @generated type ListModel interface { Do() ([]*data.Person, error) } // NewLister is the constructor for ListHandler func NewListHandler(model ListModel) *ListHandler { return &ListHandler{ lister: model, } } // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { lister ListModel } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll people, err := h.lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*data.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for 
index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch06/acme/internal/rest/list_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockListModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { output := []*data.Person{ { ID: 1, FullName: "John", Phone: "0123456789", }, { ID: 2, FullName: "Paul", Phone: "0123456781", }, { ID: 3, FullName: "George", Phone: "0123456782", }, { ID: 1, FullName: "Ringo", Phone: "0123456783", }, } mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[{"id":1,"name":"John","phone":"0123456789"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":1,"name":"Ringo","phone":"0123456783"}]}` + "\n", }, { desc: "dependency failure", inRequest: func() *http.Request { req, err := 
http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockListModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "no data", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { // no data output := []*data.Person{} mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[]}` + "\n", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockListModel := scenario.inModelMock() // build handler handler := NewListHandler(mockListModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch06/acme/internal/rest/mock_get_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockGetModel is an autogenerated mock type for the GetModel type type MockGetModel struct { mock.Mock } // Do provides a mock function with given fields: ID func (_m *MockGetModel) Do(ID int) (*data.Person, error) { ret := _m.Called(ID) var r0 *data.Person if rf, ok := ret.Get(0).(func(int) *data.Person); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { r0 
= ret.Get(0).(*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(int) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch06/acme/internal/rest/mock_list_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockListModel is an autogenerated mock type for the ListModel type type MockListModel struct { mock.Mock } // Do provides a mock function with given fields: func (_m *MockListModel) Do() ([]*data.Person, error) { ret := _m.Called() var r0 []*data.Person if rf, ok := ret.Get(0).(func() []*data.Person); ok { r0 = rf() } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func() error); ok { r1 = rf() } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch06/acme/internal/rest/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: in func (_m *MockRegisterModel) Do(in *data.Person) (int, error) { ret := _m.Called(in) var r0 int if rf, ok := ret.Get(0).(func(*data.Person) int); ok { r0 = rf(in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(*data.Person) error); ok { r1 = rf(in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch06/acme/internal/rest/not_found.go 
================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch06/acme/internal/rest/not_found_test.go ================================================ package rest import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" ) func TestNotFoundHandler_ServeHTTP(t *testing.T) { // build inputs response := httptest.NewRecorder() request := &http.Request{} // call handler notFoundHandler(response, request) // validate outputs require.Equal(t, http.StatusNotFound, response.Code) } ================================================ FILE: ch06/acme/internal/rest/register.go ================================================ package rest import ( "encoding/json" "fmt" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/data" ) // RegisterModel will validate and save a registration //go:generate mockery -name=RegisterModel -case underscore -testonly -inpkg -note @generated type RegisterModel interface { Do(in *data.Person) (int, error) } // NewRegisterHandler is the constructor for RegisterHandler func NewRegisterHandler(model RegisterModel) *RegisterHandler { return &RegisterHandler{ registerer: model, } } // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. 
type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // register person id, err := h.register(requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } return h.registerer.Do(person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch06/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel 
expectedStatus int expectedHeader string }{ { desc: "Happy Path", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // valid downstream configuration resultID := 1234 var resultErr error mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything).Return(resultID, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusCreated, expectedHeader: "/person/1234/", }, { desc: "Bad Input / User Error", inRequest: func() *http.Request { invalidRequest := bytes.NewBufferString(`this is not valid JSON`) request, err := http.NewRequest("POST", "/person/register", invalidRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // Dependency should not be called mockRegisterModel := &MockRegisterModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, { desc: "Dependency Failure", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // call to the dependency failed resultErr := errors.New("something failed") mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything).Return(0, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockRegisterModel := scenario.inModelMock() // build handler handler := NewRegisterHandler(mockRegisterModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, 
response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRegisterRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch06/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/gorilla/mux" ) // New will create and initialize the server func New(address string, getModel GetModel, listModel ListModel, registerModel RegisterModel) *Server { return &Server{ address: address, handlerGet: NewGetHandler(getModel), handlerList: NewListHandler(listModel), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(registerModel), } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, Addr: s.address, } // listen for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler 
router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch06/acme/main.go ================================================ package main import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch06/acme/internal/rest" ) func main() { // bind stop channel to context ctx := context.Background() // build model layer getModel := &get.Getter{} listModel := &list.Lister{} registerModel := ®ister.Registerer{} // start REST server server := rest.New(config.App.Address, getModel, listModel, registerModel) server.Listen(ctx.Done()) } ================================================ FILE: ch06/fake.go ================================================ package ch06 func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } ================================================ FILE: ch06/pcov-html ================================================ #!/usr/bin/env bash if [ "$1" == "" ]; then echo "No input file. Usage: pcov-html ./your-package-dir/" exit 1 fi # Inputs # Trim any ... 
from the end of the supplied directory DIR=${1%...} # Ensure there is a / at the end of the directory PKG_DIR=${DIR%/}/ # Generated coverage go test $PKG_DIR -coverprofile=$PKG_DIR/coverage.out ${@:1} # Convert coverage to HTML go tool cover -html=$PKG_DIR/coverage.out # Clean up after ourselves rm $PKG_DIR/coverage.out ================================================ FILE: ch07/01_method_injection/01_fprint.go ================================================ package method_injection import ( "fmt" "os" ) func ExampleA() { fmt.Fprint(os.Stdout, "Hello World") } ================================================ FILE: ch07/01_method_injection/02_http_request.go ================================================ package method_injection import ( "bytes" "fmt" "net/http" ) func ExampleB() { // added to make the compiler happy body := &bytes.Buffer{} // example is here req, err := http.NewRequest("POST", "/login", body) // added to make the compiler happy fmt.Printf("req: %#v / err: %s", req, err) } ================================================ FILE: ch07/01_method_injection/03_fprint.go ================================================ package method_injection import ( "io" ) // Fprint formats using the default formats for its operands and writes to w. // It returns the number of bytes written and any write error encountered. 
func Fprint(w io.Writer, a ...interface{}) (n int, err error) { return } ================================================ FILE: ch07/01_method_injection/04_http_request.go ================================================ package method_injection import ( "io" "io/ioutil" "net/http" "net/url" ) func NewRequest(method, url string, body io.Reader) (*http.Request, error) { // validate method m, err := validateMethod(method) if err != nil { return nil, err } // validate URL u, err := validateURL(url) if err != nil { return nil, err } // process body (if exists) var b io.ReadCloser if body != nil { // read body b = ioutil.NopCloser(body) } // build Request and return req := &http.Request{ URL: u, Method: m, Body: b, } return req, nil } func validateMethod(method string) (string, error) { return "", nil } func validateURL(url string) (*url.URL, error) { return nil, nil } ================================================ FILE: ch07/01_method_injection/05_timestamp_writer_v1.go ================================================ package method_injection import ( "fmt" "io" "time" ) // TimeStampWriterV1 will output the supplied message to writer preceded with a timestamp func TimeStampWriterV1(writer io.Writer, message string) { timestamp := time.Now().Format(time.RFC3339) fmt.Fprintf(writer, "%s -> %s", timestamp, message) } ================================================ FILE: ch07/01_method_injection/06_timestamp_writer_v2.go ================================================ package method_injection import ( "errors" "fmt" "io" "time" ) // TimeStampWriterV2 will output the supplied message to writer preceded with a timestamp func TimeStampWriterV2(writer io.Writer, message string) error { if writer == nil { return errors.New("writer cannot be nil") } timestamp := time.Now().Format(time.RFC3339) fmt.Fprintf(writer, "%s -> %s", timestamp, message) return nil } ================================================ FILE: ch07/01_method_injection/07_timestamp_writer_v3.go 
================================================ package method_injection import ( "fmt" "io" "os" "time" ) // TimeStampWriterV3 will output the supplied message to writer preceded with a timestamp func TimeStampWriterV3(writer io.Writer, message string) { if writer == nil { // default to Standard Out writer = os.Stdout } timestamp := time.Now().Format(time.RFC3339) fmt.Fprintf(writer, "%s -> %s", timestamp, message) } ================================================ FILE: ch07/02_advantages/01_handler_v1.go ================================================ package advantages import ( "encoding/json" "net/http" ) func HandlerV1(response http.ResponseWriter, request *http.Request) { garfield := &Animal{ Type: "Cat", Name: "Garfield", } // encode as JSON and output encoder := json.NewEncoder(response) err := encoder.Encode(garfield) if err != nil { response.WriteHeader(http.StatusInternalServerError) return } response.WriteHeader(http.StatusOK) } type Animal struct { Type string Name string } ================================================ FILE: ch07/02_advantages/02_handler_v2.go ================================================ package advantages import ( "encoding/json" "net/http" ) func HandlerV2(response http.ResponseWriter, request *http.Request) { garfield := &Animal{ Type: "Cat", Name: "Garfield", } // encode as JSON and output outputAnimal(response, garfield) } func outputAnimal(response http.ResponseWriter, animal *Animal) { encoder := json.NewEncoder(response) err := encoder.Encode(animal) if err != nil { response.WriteHeader(http.StatusInternalServerError) return } // Happy Path response.WriteHeader(http.StatusOK) } ================================================ FILE: ch07/02_advantages/03_handler_v3.go ================================================ package advantages import ( "encoding/json" "net/http" ) func HandlerV3(response http.ResponseWriter, request *http.Request) { garfield := &Animal{ Type: "Cat", Name: "Garfield", } // encode as JSON and 
output outputJSON(response, garfield) } func outputJSON(response http.ResponseWriter, data interface{}) { encoder := json.NewEncoder(response) err := encoder.Encode(data) if err != nil { response.WriteHeader(http.StatusInternalServerError) return } // Happy Path response.WriteHeader(http.StatusOK) } ================================================ FILE: ch07/02_advantages/04_context_influence.go ================================================ package advantages import ( "io" "net" "os" ) func WriteLog(writer io.Writer, message string) error { _, err := writer.Write([]byte(message)) return err } func Usage() { // Write to console WriteLog(os.Stdout, "Hello World!") // Write to file file, _ := os.Create("my-log.log") WriteLog(file, "Hello World!") // Write to TCP connection tcpPipe, _ := net.Dial("tcp", "127.0.0.1:1234") WriteLog(tcpPipe, "Hello World!") } ================================================ FILE: ch07/02_advantages/05_person_loader.go ================================================ package advantages import ( "database/sql" "encoding/json" "errors" "net/http" ) var ( // thrown when the supplied order does not exist in the database errNotFound = errors.New("order not found") ) // Loads orders based on supplied owner and order ID type OrderLoader interface { loadOrder(owner Owner, orderID int) (Order, error) } // NewLoadOrderHandler creates a new instance of LoadOrderHandler func NewLoadOrderHandler(loader OrderLoader) *LoadOrderHandler { return &LoadOrderHandler{ loader: loader, } } // LoadOrderHandler is a HTTP handler that loads orders based on the current user and supplied user ID type LoadOrderHandler struct { loader OrderLoader } // ServeHTTP implements http.Handler func (l *LoadOrderHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract user from supplied authentication credentials currentUser, err := l.authenticateUser(request) if err != nil { response.WriteHeader(http.StatusUnauthorized) return } // extract order 
ID from request orderID, err := l.extractOrderID(request) if err != nil { response.WriteHeader(http.StatusBadRequest) return } // load order using the current user as a request-scoped dependency // (with method injection) order, err := l.loader.loadOrder(currentUser, orderID) if err != nil { response.WriteHeader(http.StatusInternalServerError) return } // output order encoder := json.NewEncoder(response) err = encoder.Encode(order) if err != nil { response.WriteHeader(http.StatusInternalServerError) return } response.WriteHeader(http.StatusOK) } // AuthenticatedLoader will load orders for based on the supplied owner type AuthenticatedLoader struct { // This pool is expensive to create. We will want to create it once and then reuse it. db *sql.DB } // load the order from the database based on owner and order ID func (a *AuthenticatedLoader) loadByOwner(owner Owner, orderID int) (*Order, error) { order, err := a.load(orderID) if err != nil { return nil, err } if order.OwnerID != owner.ID() { // Return not found so we do not leak information to hackers return nil, errNotFound } // happy path return order, nil } func (a *AuthenticatedLoader) load(orderID int) (*Order, error) { // load order from DB return &Order{OwnerID: 1}, nil } type Owner interface { ID() int } type Order struct { OwnerID int // other order details } type User struct { id int // other attributes } func (u *User) ID() int { return u.id } // Extract the user from the request (e.g. from a JWT token). func (l *LoadOrderHandler) authenticateUser(request *http.Request) (*User, error) { return &User{id: 1}, nil } // Extract the order ID from the request (e.g. 
from the URL or HTTP POST body) func (l *LoadOrderHandler) extractOrderID(request *http.Request) (int, error) { return 2, nil } ================================================ FILE: ch07/04_disadvantages/01_data_struct.go ================================================ package disadvantages import ( "database/sql" "errors" ) // Load people from the database type PersonLoader struct { } func (d *PersonLoader) Load(db *sql.DB, ID int) (*Person, error) { return nil, errors.New("not implemented") } func (d *PersonLoader) LoadAll(db *sql.DB) ([]*Person, error) { return nil, errors.New("not implemented") } type Person struct { Name string Age int } ================================================ FILE: ch07/04_disadvantages/02_ux_improvement.go ================================================ package disadvantages type MyPersonLoader interface { Load(ID int) (*Person, error) } ================================================ FILE: ch07/04_disadvantages/03_many_params.go ================================================ package disadvantages import ( "io" ) type Generator struct{} func (g *Generator) Generate(storage Storage, template io.Reader, destination io.Writer, renderer Renderer, formatter Formatter, params ...interface{}) { } type Storage interface { Load() []interface{} } type Renderer interface { Render(template io.Reader, params ...interface{}) []byte } type Formatter interface { Format([]byte) []byte } ================================================ FILE: ch07/04_disadvantages/04_many_params_v2.go ================================================ package disadvantages import ( "io" ) func NewGeneratorV2(storage Storage, renderer Renderer, formatter Formatter) *GeneratorV2 { return &GeneratorV2{ storage: storage, renderer: renderer, formatter: formatter, } } type GeneratorV2 struct { storage Storage renderer Renderer formatter Formatter } func (g *GeneratorV2) Generate(template io.Reader, destination io.Writer, params ...interface{}) { } 
================================================ FILE: ch07/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // App is the application config var App *Config // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string } // Load returns the config loaded from environment func init() { filename, found := os.LookupEnv(DefaultEnvVar) if !found { logging.L.Error("failed to locate file specified by %s", DefaultEnvVar) return } _ = load(filename) } func load(filename string) error { App = &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { logging.L.Error("failed to read config file. err: %s", err) return err } err = json.Unmarshal(bytes, App) if err != nil { logging.L.Error("failed to parse config file. 
err : %s", err) return err } return nil } ================================================ FILE: ch07/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: &Config{}, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, App, scenario.desc) }) } } ================================================ FILE: ch07/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // L is the global instance of the logger var L = &LoggerStdOut{} // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) } // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) 
} ================================================ FILE: ch07/acme/internal/modules/data/data.go ================================================ package data import ( // import the MySQL Driver "context" "database/sql" "errors" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/logging" _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) var getDB = func() (*sql.DB, error) { if db == nil { if config.App == nil { return nil, errors.New("config is not initialized") } var err error db, err = sql.Open("mysql", config.App.DSN) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. 
func Save(ctx context.Context, in *Person) (int, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB insert result, err := db.ExecContext(subCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { logging.L.Error("failed to save person into DB. err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { logging.L.Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func LoadAll(ctx context.Context) ([]*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select rows, err := db.QueryContext(subCtx, sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { logging.L.Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { logging.L.Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. 
func Load(ctx context.Context, ID int) (*Person, error) { db, err := getDB() if err != nil { logging.L.Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select row := db.QueryRowContext(subCtx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { logging.L.Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } logging.L.Error("failed to convert query result. err: %s", err) return nil, err } return out, nil } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } func init() { // ensure the config is loaded and the db initialized _, _ = getDB() } ================================================ FILE: ch07/acme/internal/modules/data/data_test.go ================================================ package data import ( "context" "database/sql" "errors" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here defer func(original sql.DB) { // restore original DB (after test) db = &original }(*db) 
// replace db for this test db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(ctx, in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here defer func(original sql.DB) { // restore original DB (after test) db = &original }(*db) // replace db for this test db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(ctx, in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patching starts here defer func(original func() (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func() (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(ctx, in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { // define context and therefore test timeout 
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function results, err := LoadAll(ctx) // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function result, err := Load(ctx, 2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } ================================================ FILE: ch07/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct{} // Do will perform the conversion func (c *Converter) Do(ctx context.Context, basePrice float64, 
currency string) (float64, error) { // load rate from the external API response, err := c.loadRateFromServer(ctx, currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, config.App.ExchangeRateBaseURL, config.App.ExchangeRateAPIKey, currency) // perform request req, err := http.NewRequest("GET", url, nil) if err != nil { logging.L.Warn("[exchange] failed to create request. err: %s", err) return nil, err } // set latency budget for the upstream call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // replace the default context with our custom one req = req.WithContext(subCtx) // perform the HTTP request response, err := http.DefaultClient.Do(req) if err != nil { logging.L.Warn("[exchange] failed to load. 
err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) logging.L.Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) logging.L.Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { logging.L.Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { logging.L.Error("[exchange] error converting response. err: %s", err) return nil, err } // happy path return data, nil } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch07/acme/internal/modules/get/get.go ================================================ package get import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // Getter will attempt to load a person. 
// It can return an error caused by the data layer or when the requested person is not found type Getter struct { } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := loader(context.TODO(), ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. return nil, errPersonNotFound } return nil, err } return person, err } // this function as a variable allows us to Monkey Patch during testing var loader = data.Load ================================================ FILE: ch07/acme/internal/modules/get/go_test.go ================================================ package get import ( "context" "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(_ context.Context, ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context, ID int) (*data.Person, error) { result := &data.Person{ ID: 1234, FullName: "Doug", } var resultErr error return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // monkey patch calls to the data package defer func(original func(_ context.Context, ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context, ID int) (*data.Person, error) { var result *data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call 
method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(_ context.Context, ID int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context, ID int) (*data.Person, error) { var result *data.Person resultErr := errors.New("failed to load person") return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Error(t, err) assert.Nil(t, person) } ================================================ FILE: ch07/acme/internal/modules/list/list.go ================================================ package list import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // Lister will attempt to load all people in the database. // It can return an error caused by the data layer type Lister struct { } // Do will load the people from the data layer func (l *Lister) Do() ([]*data.Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return people, nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := loader(context.TODO()) if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. 
return nil, errPeopleNotFound } return nil, err } return people, nil } // this function as a variable allows us to Monkey Patch during testing var loader = data.LoadAll ================================================ FILE: ch07/acme/internal/modules/list/list_test.go ================================================ package list import ( "context" "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // monkey patch calls to the data package defer func(original func(_ context.Context) ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context) ([]*data.Person, error) { result := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } var resultErr error return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) } func TestLister_Do_noResults(t *testing.T) { // monkey patch calls to the data package defer func(original func(_ context.Context) ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context) ([]*data.Person, error) { var result []*data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) } func TestLister_Do_error(t *testing.T) { // monkey patch calls to the data package defer func(original func(_ context.Context) ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context) ([]*data.Person, error) { var result []*data.Person 
resultErr := errors.New("failed to load people") return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) } ================================================ FILE: ch07/acme/internal/modules/register/register.go ================================================ package register import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/exchange" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. 
type Registerer struct { } // Do is API for this struct func (r *Registerer) Do(ctx context.Context, in *data.Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { logging.L.Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(ctx, in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(ctx, in, price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *data.Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) { converter := &exchange.Converter{} price, err := converter.Do(ctx, config.App.BasePrice, currency) if err != nil { logging.L.Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(ctx context.Context, in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return saver(ctx, person) } // this function as a variable allows us to Monkey Patch during testing var saver = data.Save ================================================ FILE: ch07/acme/internal/modules/register/register_test.go ================================================ package register import ( "context" "errors" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patch calls to the data package defer func(original func(ctx context.Context, in *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(ctx context.Context, in *data.Person) (int, error) { result := 888 var resultErr error return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{} ID, err := registerer.Do(ctx, in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) } func TestRegisterer_Do_error(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patch calls to the data package defer func(original func(ctx context.Context, in *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(ctx context.Context, in *data.Person) (int, error) { var result 
int resultErr := errors.New("failed to save") return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{} ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) } ================================================ FILE: ch07/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration //go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated type GetModel interface { Do(ID int) (*data.Person, error) } // NewGetHandler is the constructor for GetHandler func NewGetHandler(model GetModel) *GetHandler { return &GetHandler{ getter: model, } } // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. 
type GetHandler struct { getter GetModel } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") logging.L.Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) logging.L.Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch07/acme/internal/rest/get_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockGetModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { output := &data.Person{ ID: 1, FullName: "John", Phone: "0123456789", Currency: "USD", Price: 100, } mockGetModel := &MockGetModel{} mockGetModel.On("Do", mock.Anything).Return(output, nil).Once() return mockGetModel }, expectedStatus: http.StatusOK, expectedPayload: `{"id":1,"name":"John","phone":"0123456789","currency":"USD","price":100}` + "\n", }, { desc: "bad input (ID is 
invalid)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/x/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "x"}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "bad input (ID is missing)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person//", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "dependency fail", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "requested registration does not exist", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("person not found")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t 
*testing.T) { // define model layer mock mockGetModel := scenario.inModelMock() // build handler handler := NewGetHandler(mockGetModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch07/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" ) // ListModel will load all registrations //go:generate mockery -name=ListModel -case underscore -testonly -inpkg -note @generated type ListModel interface { Do() ([]*data.Person, error) } // NewLister is the constructor for ListHandler func NewListHandler(model ListModel) *ListHandler { return &ListHandler{ lister: model, } } // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { lister ListModel } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll people, err := h.lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*data.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for 
index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch07/acme/internal/rest/list_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockListModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { output := []*data.Person{ { ID: 1, FullName: "John", Phone: "0123456789", }, { ID: 2, FullName: "Paul", Phone: "0123456781", }, { ID: 3, FullName: "George", Phone: "0123456782", }, { ID: 1, FullName: "Ringo", Phone: "0123456783", }, } mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[{"id":1,"name":"John","phone":"0123456789"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":1,"name":"Ringo","phone":"0123456783"}]}` + "\n", }, { desc: "dependency failure", inRequest: func() *http.Request { req, err := 
http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockListModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "no data", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { // no data output := []*data.Person{} mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[]}` + "\n", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockListModel := scenario.inModelMock() // build handler handler := NewListHandler(mockListModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch07/acme/internal/rest/mock_get_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockGetModel is an autogenerated mock type for the GetModel type type MockGetModel struct { mock.Mock } // Do provides a mock function with given fields: ID func (_m *MockGetModel) Do(ID int) (*data.Person, error) { ret := _m.Called(ID) var r0 *data.Person if rf, ok := ret.Get(0).(func(int) *data.Person); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { r0 
= ret.Get(0).(*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(int) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch07/acme/internal/rest/mock_list_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockListModel is an autogenerated mock type for the ListModel type type MockListModel struct { mock.Mock } // Do provides a mock function with given fields: func (_m *MockListModel) Do() ([]*data.Person, error) { ret := _m.Called() var r0 []*data.Person if rf, ok := ret.Get(0).(func() []*data.Person); ok { r0 = rf() } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func() error); ok { r1 = rf() } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch07/acme/internal/rest/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: ctx, in func (_m *MockRegisterModel) Do(ctx context.Context, in *data.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok := ret.Get(0).(func(context.Context, *data.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } 
================================================ FILE: ch07/acme/internal/rest/not_found.go ================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch07/acme/internal/rest/not_found_test.go ================================================ package rest import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" ) func TestNotFoundHandler_ServeHTTP(t *testing.T) { // build inputs response := httptest.NewRecorder() request := &http.Request{} // call handler notFoundHandler(response, request) // validate outputs require.Equal(t, http.StatusNotFound, response.Code) } ================================================ FILE: ch07/acme/internal/rest/register.go ================================================ package rest import ( "context" "encoding/json" "fmt" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/data" ) // RegisterModel will validate and save a registration //go:generate mockery -name=RegisterModel -case underscore -testonly -inpkg -note @generated type RegisterModel interface { Do(ctx context.Context, in *data.Person) (int, error) } // NewRegisterHandler is the constructor for RegisterHandler func NewRegisterHandler(model RegisterModel) *RegisterHandler { return &RegisterHandler{ registerer: model, } } // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. 
type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // set latency budget for this API subCtx, cancel := context.WithTimeout(request.Context(), 1500*time.Millisecond) defer cancel() // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // call the business logic using the request data and context id, err := h.register(subCtx, requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(ctx context.Context, requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } return h.registerer.Do(ctx, person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch07/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ { desc: "Happy Path", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // valid downstream configuration resultID := 1234 var resultErr error mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(resultID, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusCreated, expectedHeader: "/person/1234/", }, { desc: "Bad Input / User Error", inRequest: func() *http.Request { invalidRequest := bytes.NewBufferString(`this is not valid JSON`) request, err := http.NewRequest("POST", "/person/register", invalidRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // Dependency should not be called mockRegisterModel := &MockRegisterModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, { desc: "Dependency Failure", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // call to the dependency failed resultErr := errors.New("something failed") mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(0, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockRegisterModel := 
scenario.inModelMock() // build handler handler := NewRegisterHandler(mockRegisterModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRegisterRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch07/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/gorilla/mux" ) // New will create and initialize the server func New(address string, getModel GetModel, listModel ListModel, registerModel RegisterModel) *Server { return &Server{ address: address, handlerGet: NewGetHandler(getModel), handlerList: NewListHandler(listModel), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(registerModel), } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, Addr: s.address, } // listen for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP 
handlers router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch07/acme/main.go ================================================ package main import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch07/acme/internal/rest" ) func main() { // bind stop channel to context ctx := context.Background() // build model layer getModel := &get.Getter{} listModel := &list.Lister{} registerModel := ®ister.Registerer{} // start REST server server := rest.New(config.App.Address, getModel, listModel, registerModel) server.Listen(ctx.Done()) } ================================================ FILE: ch07/fake.go ================================================ package ch07 func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } ================================================ FILE: ch08/01_config_injection/01_long_constructor.go ================================================ package config_injection import ( "time" ) // NewLongConstructor is the constructor for MyStruct func NewLongConstructor(logger Logger, stats Instrumentation, limiter RateLimiter, cache Cache, timeout time.Duration, workers int) *MyStruct { return &MyStruct{ // code removed } } // MyStruct does something fantastic type MyStruct struct { } // Logger 
logs stuff type Logger interface { Error(message string, args ...interface{}) Warn(message string, args ...interface{}) Info(message string, args ...interface{}) Debug(message string, args ...interface{}) } // Instrumentation records the performances and events type Instrumentation interface { Count(key string, value int) Duration(key string, start time.Time) } // RateLimiter limits how many concurrent requests we can make or process type RateLimiter interface { Acquire() Release() } // Cache will store/retrieve data in a fast way type Cache interface { Store(key string, data []byte) Get(key string) ([]byte, error) } ================================================ FILE: ch08/01_config_injection/02_by_config_example.go ================================================ package config_injection import ( "time" ) // NewByConfigConstructor is the constructor for MyStruct func NewByConfigConstructor(cfg MyConfig, limiter RateLimiter, cache Cache) *MyStruct { return &MyStruct{ // code removed } } // MyConfig defines the config for MyStruct type MyConfig interface { Logger() Logger Instrumentation() Instrumentation Timeout() time.Duration Workers() int } ================================================ FILE: ch08/01_config_injection/03_shared_params.go ================================================ package config_injection import ( "fmt" "time" ) func Usage() { cfg := &fakeConfig{} myFetcher := NewFetcher(cfg, cfg.URL(), cfg.Timeout()) // do something with the object so the compiler does not complain fmt.Printf("%#v", myFetcher) } type FetcherConfig interface { Logger() Logger Instrumentation() Instrumentation } func NewFetcher(cfg FetcherConfig, url string, timeout time.Duration) *MyObject { return nil } type MyObject struct{} // fake implementation of the FetcherConfig interface type fakeConfig struct{} // Logger implements FetcherConfig func (f *fakeConfig) Logger() Logger { return nil } // Instrumentation implements FetcherConfig func (f *fakeConfig) 
Instrumentation() Instrumentation { return nil } func (f *fakeConfig) URL() string { return "" } func (f *fakeConfig) Timeout() time.Duration { return time.Duration(0) } ================================================ FILE: ch08/02_advantages/01_injected_config/01.go ================================================ package injected_config import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/config" ) func NewMyObject(cfg *config.Config) *MyObject { return &MyObject{ cfg: cfg, } } type MyObject struct { cfg *config.Config } func (m *MyObject) Do() (interface{}, error) { m.cfg.Logger().Error("not implemented") return struct{}{}, nil } ================================================ FILE: ch08/02_advantages/01_injected_config/01_test.go ================================================ package injected_config import ( "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const ( testConfigLocation = "" ) func TestInjectedConfig(t *testing.T) { // load test config cfg, err := config.LoadFromFile(testConfigLocation) require.NoError(t, err) // build and use object obj := NewMyObject(cfg) result, resultErr := obj.Do() // validate assert.NotNil(t, result) assert.NoError(t, resultErr) } ================================================ FILE: ch08/02_advantages/02_config_injection/02.go ================================================ package config_injection import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/stats" ) func NewMyObject(cfg Config) *MyObject { return &MyObject{ cfg: cfg, } } type Config interface { Logger() *logging.Logger Stats() *stats.Collector } type MyObject struct { cfg Config } func (m *MyObject) Do() (interface{}, error) { m.cfg.Logger().Error("not implemented") 
return struct{}{}, nil } ================================================ FILE: ch08/02_advantages/02_config_injection/02_test.go ================================================ package config_injection import ( "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/stats" "github.com/stretchr/testify/assert" ) func TestConfigInjection(t *testing.T) { // build test config cfg := &TestConfig{} // build and use object obj := NewMyObject(cfg) result, resultErr := obj.Do() // validate assert.NotNil(t, result) assert.NoError(t, resultErr) } // Simple implementation of the Config interface type TestConfig struct { logger *logging.Logger stats *stats.Collector } func (t *TestConfig) Logger() *logging.Logger { return t.logger } func (t *TestConfig) Stats() *stats.Collector { return t.stats } ================================================ FILE: ch08/02_advantages/03_long_constructor.go ================================================ package config_injection import ( "time" ) // NewLongConstructor is the constructor for MyStruct func NewLongConstructor(logger Logger, stats Instrumentation, limiter RateLimiter, cache Cache, url string, credentials string) *MyStruct { return &MyStruct{ // code removed } } // MyStruct does something fantastic type MyStruct struct { } // Logger logs stuff type Logger interface { Error(message string, args ...interface{}) Warn(message string, args ...interface{}) Info(message string, args ...interface{}) Debug(message string, args ...interface{}) } // Instrumentation records the performances and events type Instrumentation interface { Count(key string, value int) Duration(key string, start time.Time) } // RateLimiter limits how many concurrent requests we can make or process type RateLimiter interface { Acquire() Release() } // Cache will store/retrieve data in a fast way type Cache interface { Store(key string, 
data []byte) Get(key string) ([]byte, error) } ================================================ FILE: ch08/02_advantages/04_by_config_example.go ================================================ package config_injection // NewByConfigConstructor is the constructor for MyStruct func NewByConfigConstructor(cfg MyConfig, url string, credentials string) *MyStruct { return &MyStruct{ // code removed } } // MyConfig defines the config for MyStruct type MyConfig interface { Logger() Logger Instrumentation() Instrumentation RateLimiter() RateLimiter Cache() Cache } ================================================ FILE: ch08/02_advantages/config/config.go ================================================ package config import ( "sync" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/02_advantages/stats" ) // LoadFromFile loads the config from the supplied path func LoadFromFile(path string) (*Config, error) { // TODO: implement return &Config{}, nil } // Config is the result of loading config from a file type Config struct { // Log config LogLevel int `json:"log_level"` logger *logging.Logger loggerInitOnce sync.Once // Instrumentation config StatsDHostAndPort string `json:"stats_d_host_and_port"` stats *stats.Collector statsInitOnce sync.Once // Rate Limiter config RateLimiterMaxConcurrent int `json:"rate_limiter_max_concurrent"` } func (c *Config) Logger() *logging.Logger { c.loggerInitOnce.Do(func() { // use log level to create new logger c.logger = &logging.Logger{ Level: c.LogLevel, } }) return c.logger } func (c *Config) Stats() *stats.Collector { c.statsInitOnce.Do(func() { c.stats = &stats.Collector{ HostAndPort: c.StatsDHostAndPort, } }) return c.stats } ================================================ FILE: ch08/02_advantages/logging/logger.go ================================================ package logging // Logger logs stuff type Logger struct { Level 
int } // Error outputs a log at level ERROR func (l *Logger) Error(message string, args ...interface{}) { // TODO: implement } // Warn outputs a log at level ERROR func (l *Logger) Warn(message string, args ...interface{}) { // TODO: implement } // Info outputs a log at level ERROR func (l *Logger) Info(message string, args ...interface{}) { // TODO: implement } // Debug outputs a log at level ERROR func (l *Logger) Debug(message string, args ...interface{}) { // TODO: implement } ================================================ FILE: ch08/02_advantages/stats/stats.go ================================================ package stats import ( "time" ) // Collector collects and forwards stats type Collector struct { HostAndPort string } // Count will record an event func (c *Collector) Count(key string, value int) { // TODO: implement } // Count will record the duration of an event func (c *Collector) Duration(key string, start time.Time) { // TODO: implement } ================================================ FILE: ch08/03_applying/01_define_register_config.go ================================================ // +build do-not-build package applying import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" ) // Config is the configuration for the Registerer type Config interface { Logger() *logging.LoggerStdOut BasePrice() float64 } ================================================ FILE: ch08/03_applying/02_register_with_config_injection.go ================================================ // +build do-not-build package applying import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/exchange" ) // NewRegisterer creates and initializes a Registerer func NewRegisterer(cfg Config) *Registerer { return &Registerer{ cfg: cfg, } } // Config is the configuration for the Registerer type Config 
interface { Logger() logging.Logger RegistrationBasePrice() float64 } // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. type Registerer struct { cfg Config } // get price in the requested currency func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) { converter := &exchange.Converter{} price, err := converter.Do(ctx, r.cfg.RegistrationBasePrice(), currency) if err != nil { r.logger().Warn("failed to convert the price. err: %s", err) return defaultPersonID, err } return price, nil } func (r *Registerer) logger() logging.Logger { return r.cfg.Logger() } ================================================ FILE: ch08/03_applying/03_model_before_data_changes.go ================================================ // +build do-not-build package applying import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" ) // GetterConfig is the configuration for Getter type GetterConfig interface { Logger() logging.Logger } // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { cfg GetterConfig } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := loader(context.TODO(), ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. 
return nil, errPersonNotFound } return nil, err } return person, err } // this function as a variable allows us to Monkey Patch during testing var loader = data.Load ================================================ FILE: ch08/03_applying/04_test_config_link_to_config_package.go ================================================ // +build do-not-build package applying import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" ) type testConfig struct{} // Logger implement Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // RegistrationBasePrice implement Config func (t *testConfig) RegistrationBasePrice() float64 { return 12.34 } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } // ExchangeBaseURL implements Config func (t *testConfig) ExchangeBaseURL() string { return config.App.ExchangeRateBaseURL } // ExchangeAPIKey implements Config func (t *testConfig) ExchangeAPIKey() string { return config.App.ExchangeRateAPIKey } ================================================ FILE: ch08/03_applying/05_result_payload.json ================================================ { "success":true, "historical":true, "date":"2010-11-09", "timestamp":1289347199, "source":"USD", "quotes":{ "USDAUD":0.989981 } } ================================================ FILE: ch08/03_applying/06_simple_test_server.go ================================================ // +build do-not-build package applying import ( "net/http" ) type happyExchangeRateService struct{} // ServeHTTP implements http.Handler func (*happyExchangeRateService) ServeHTTP(response http.ResponseWriter, request *http.Request) { payload := []byte(` { "success":true, "historical":true, "date":"2010-11-09", "timestamp":1289347199, "source":"USD", "quotes":{ "USDAUD":0.989981 } }`) response.Write(payload) } 
================================================
FILE: ch08/04_disadvantages/01_leaking_details.go
================================================
package disadvantages

import (
	"errors"
	"strings"
)

// PeopleFilterConfig is the config required by PeopleFilter. It only exists
// so that PeopleFilter can pass it on to PersonLoader.
type PeopleFilterConfig interface {
	DSN() string
}

// PeopleFilter loads all people and returns those whose Name contains filter
// (a case-sensitive substring match). Any loader error is returned unchanged.
// cfg is not used here directly; it is forwarded to the loader, which leaks
// the loader's configuration needs into this function's signature.
func PeopleFilter(cfg PeopleFilterConfig, filter string) ([]Person, error) {
	// load people
	loader := &PersonLoader{}
	people, err := loader.LoadAll(cfg)
	if err != nil {
		return nil, err
	}

	// filter people
	out := []Person{}
	for _, person := range people {
		if strings.Contains(person.Name, filter) {
			out = append(out, person)
		}
	}

	return out, nil
}

// PersonLoaderConfig is the config required by PersonLoader.LoadAll.
type PersonLoaderConfig interface {
	DSN() string
}

// PersonLoader loads people from a data source.
type PersonLoader struct{}

// LoadAll is a stub for this example; it always returns a "not implemented"
// error.
func (p *PersonLoader) LoadAll(cfg PersonLoaderConfig) ([]Person, error) {
	return nil, errors.New("not implemented")
}

// Person is the example data type being loaded and filtered.
type Person struct {
	Name string
}

================================================
FILE: ch08/04_disadvantages/02_hiding_details.go
================================================
package disadvantages

import (
	"strings"
)

// Loader loads all people; how (and with what config) is hidden from callers.
type Loader interface {
	LoadAll() ([]Person, error)
}

// PeopleFilterV2 loads all people through the injected loader and returns
// those whose Name contains filter (a case-sensitive substring match).
// Any loader error is returned unchanged.
func PeopleFilterV2(loader Loader, filter string) ([]Person, error) {
	// load people
	people, err := loader.LoadAll()
	if err != nil {
		return nil, err
	}

	// filter people
	out := []Person{}
	for _, person := range people {
		if strings.Contains(person.Name, filter) {
			out = append(out, person)
		}
	}

	return out, nil
}

================================================
FILE: ch08/04_disadvantages/03_unclear_lifecycle.go
================================================
package disadvantages

import (
	"errors"
	"time"
)

// DoJob waits up to one second for the pool's IsReady channel to deliver (or
// close), then runs job on a worker taken from the pool. It returns a timeout
// error if the pool is not ready in time, otherwise the worker's result.
func DoJob(pool WorkerPool, job Job) error {
	// wait for pool
	ready := pool.IsReady()
	select {
	case <-ready:
		// happy path

	case <-time.After(1 * time.Second):
		return errors.New("timeout waiting for worker pool")
	}

	worker := pool.GetWorker()
	return worker.Do(job)
}

// WorkerPool is a pool of workers
type WorkerPool interface {
	GetWorker() Worker
	IsReady() chan struct{}
}

// Worker executes/processes a unit of work and returns an error on failure
type Worker interface { Do(job Job) error } // A unit of work to be executed against the pool type Job interface { // implementation omitted } ================================================ FILE: ch08/04_disadvantages/04_clear_lifecycle.go ================================================ package disadvantages func DoJobUpdated(pool WorkerPool, job Job) error { worker := pool.GetWorker() return worker.Do(job) } ================================================ FILE: ch08/04_disadvantages/05_layers.go ================================================ package disadvantages func NewLayer1Object(config Layer1Config) *Layer1Object { return &Layer1Object{ MyConfig: config, MyDependency: NewLayer2Object(config), } } // Configuration for the Layer 1 Object type Layer1Config interface { Logger() Logger } // Layer 1 Object type Layer1Object struct { MyConfig Layer1Config MyDependency *Layer2Object } // Configuration for the Layer 2 Object type Layer2Config interface { Logger() Logger } // Layer 2 Object type Layer2Object struct { MyConfig Layer2Config } func NewLayer2Object(config Layer2Config) *Layer2Object { return &Layer2Object{ MyConfig: config, } } // Stub implementation to make the compiler happy type Logger interface { } ================================================ FILE: ch08/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // App is the application config var App *Config // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration 
BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string // environmental dependencies logger logging.Logger } // Logger returns a reference to the singleton logger func (c *Config) Logger() logging.Logger { if c.logger == nil { c.logger = &logging.LoggerStdOut{} } return c.logger } // RegistrationBasePrice returns the base price for registrations func (c *Config) RegistrationBasePrice() float64 { return c.BasePrice } // DataDSN returns the DSN func (c *Config) DataDSN() string { return c.DSN } // ExchangeBaseURL returns the Base URL from which we can load exchange rates func (c *Config) ExchangeBaseURL() string { return c.ExchangeRateBaseURL } // ExchangeAPIKey returns the DSN func (c *Config) ExchangeAPIKey() string { return c.ExchangeRateAPIKey } // BindAddress returns the host and port this service should bind to func (c *Config) BindAddress() string { return c.Address } // Load returns the config loaded from environment func init() { filename, found := os.LookupEnv(DefaultEnvVar) if !found { logging.L.Error("failed to locate file specified by %s", DefaultEnvVar) return } _ = load(filename) } func load(filename string) error { App = &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { logging.L.Error("failed to read config file. err: %s", err) return err } err = json.Unmarshal(bytes, App) if err != nil { logging.L.Error("failed to parse config file. 
err : %s", err) return err } return nil } ================================================ FILE: ch08/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: &Config{}, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, App, scenario.desc) }) } } ================================================ FILE: ch08/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // Logger is our standard interface type Logger interface { Debug(message string, args ...interface{}) Info(message string, args ...interface{}) Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } // L is the global instance of the logger var L = &LoggerStdOut{} // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) 
} // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) } ================================================ FILE: ch08/acme/internal/modules/data/data.go ================================================ package data import ( // import the MySQL Driver "context" "database/sql" "errors" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) // Config is the configuration for the data package type Config interface { // Logger returns a reference to the logger Logger() logging.Logger // DataDSN returns the data source name DataDSN() string } var getDB = func(cfg Config) (*sql.DB, error) { if db == nil { var err error db, err = sql.Open("mysql", cfg.DataDSN()) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func Save(ctx context.Context, cfg Config, in *Person) (int, error) { db, err := getDB(cfg) if err != nil { cfg.Logger().Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB insert result, err := db.ExecContext(subCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { cfg.Logger().Error("failed to save person into DB. err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { cfg.Logger().Error("failed to retrieve id of last saved person. 
err: %s", err) return defaultPersonID, err } return int(id), nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func LoadAll(ctx context.Context, cfg Config) ([]*Person, error) { db, err := getDB(cfg) if err != nil { cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select rows, err := db.QueryContext(subCtx, sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { cfg.Logger().Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { cfg.Logger().Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func Load(ctx context.Context, cfg Config, ID int) (*Person, error) { db, err := getDB(cfg) if err != nil { cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select row := db.QueryRowContext(subCtx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { cfg.Logger().Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } cfg.Logger().Error("failed to convert query result. 
err: %s", err) return nil, err } return out, nil } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch08/acme/internal/modules/data/data_test.go ================================================ package data import ( "context" "database/sql" "errors" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(ctx, &testConfig{}, in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) 
dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(ctx, &testConfig{}, in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patching starts here defer func(original func(_ Config) (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func(_ Config) (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function resultID, err := Save(ctx, &testConfig{}, in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function results, err := LoadAll(ctx, &testConfig{}) // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function result, err := Load(ctx, &testConfig{}, 2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } type testConfig struct{} // Logger implements Config func (t *testConfig) Logger() logging.Logger { return logging.LoggerStdOut{} } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } ================================================ FILE: ch08/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // NewConverter creates and initializes the converter func NewConverter(cfg Config) *Converter { return &Converter{ cfg: cfg, } } // 
Config is the config for Converter type Config interface { Logger() logging.Logger ExchangeBaseURL() string ExchangeAPIKey() string } // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct { cfg Config } // Exchange will perform the conversion func (c *Converter) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { // load rate from the external API response, err := c.loadRateFromServer(ctx, currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, c.cfg.ExchangeBaseURL(), c.cfg.ExchangeAPIKey(), currency) // perform request req, err := http.NewRequest("GET", url, nil) if err != nil { c.logger().Warn("[exchange] failed to create request. err: %s", err) return nil, err } // set latency budget for the upstream call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // replace the default context with our custom one req = req.WithContext(subCtx) // perform the HTTP request response, err := http.DefaultClient.Do(req) if err != nil { c.logger().Warn("[exchange] failed to load. 
err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) c.logger().Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) c.logger().Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { c.logger().Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { c.logger().Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } func (c *Converter) logger() logging.Logger { return c.cfg.Logger() } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch08/acme/internal/modules/exchange/converter_ext_bounday_test.go ================================================ // +build external package exchange import ( "context" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestExternalBoundaryTest(t *testing.T) { // define the config cfg := &testConfig{ baseURL: config.App.ExchangeRateBaseURL, apiKey: config.App.ExchangeRateAPIKey, } // create a converter to test converter := NewConverter(cfg) // fetch from the server response, err := converter.loadRateFromServer(context.Background(), "AUD") require.NotNil(t, response) require.NoError(t, err) // parse the response resultRate, err := converter.extractRate(response, "AUD") require.NoError(t, err) // validate the result assert.True(t, resultRate > 0) } ================================================ FILE: ch08/acme/internal/modules/exchange/converter_int_bounday_test.go ================================================ package exchange import ( "context" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/stretchr/testify/assert" ) func TestInternalBoundaryTest(t *testing.T) { // start our test server server := httptest.NewServer(&happyExchangeRateService{}) defer server.Close() // define the config cfg := &testConfig{ baseURL: server.URL, apiKey: "", } // create a converter to test converter := NewConverter(cfg) resultRate, resultErr := converter.Exchange(context.Background(), 100.00, "AUD") // validate the result assert.Equal(t, 101.01, 
resultRate) assert.NoError(t, resultErr) } type happyExchangeRateService struct{} // ServeHTTP implements http.Handler func (*happyExchangeRateService) ServeHTTP(response http.ResponseWriter, request *http.Request) { payload := []byte(` { "success":true, "historical":true, "date":"2010-11-09", "timestamp":1289347199, "source":"USD", "quotes":{ "USDAUD":0.989981 } }`) response.Write(payload) } // test implementation of Config type testConfig struct { baseURL string apiKey string } // Logger implements Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // ExchangeBaseURL implements Config func (t *testConfig) ExchangeBaseURL() string { return t.baseURL } // ExchangeAPIKey implements Config func (t *testConfig) ExchangeAPIKey() string { return t.apiKey } ================================================ FILE: ch08/acme/internal/modules/get/get.go ================================================ package get import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // NewGetter creates and initializes a Getter func NewGetter(cfg Config) *Getter { return &Getter{ cfg: cfg, } } // Config is the configuration for Getter type Config interface { Logger() logging.Logger DataDSN() string } // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { cfg Config } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := loader(context.TODO(), g.cfg, ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. 
return nil, errPersonNotFound } return nil, err } return person, err } // this function as a variable allows us to Monkey Patch during testing var loader = data.Load ================================================ FILE: ch08/acme/internal/modules/get/go_test.go ================================================ package get import ( "context" "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(_ context.Context, _ data.Config, _ int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context, _ data.Config, _ int) (*data.Person, error) { result := &data.Person{ ID: 1234, FullName: "Doug", } var resultErr error return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // monkey patch calls to the data package defer func(original func(_ context.Context, _ data.Config, _ int) (*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context, _ data.Config, _ int) (*data.Person, error) { var result *data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // monkey patch calls to the data package defer func(original func(_ context.Context, _ data.Config, _ int) (*data.Person, error)) { // 
restore original loader = original }(loader) // replace method loader = func(_ context.Context, _ data.Config, _ int) (*data.Person, error) { var result *data.Person resultErr := errors.New("failed to load person") return result, resultErr } // end of monkey patch // call method getter := &Getter{} person, err := getter.Do(ID) // validate expectations require.Error(t, err) assert.Nil(t, person) } ================================================ FILE: ch08/acme/internal/modules/list/list.go ================================================ package list import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // NewLister creates and initializes a Lister func NewLister(cfg Config) *Lister { return &Lister{ cfg: cfg, } } // Config is the config for Lister type Config interface { Logger() logging.Logger DataDSN() string } // Lister will attempt to load all people in the database. // It can return an error caused by the data layer type Lister struct { cfg Config } // Exchange will load the people from the data layer func (l *Lister) Do() ([]*data.Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return people, nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := loader(context.TODO(), l.cfg) if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. 
return nil, errPeopleNotFound } return nil, err } return people, nil } // this function as a variable allows us to Monkey Patch during testing var loader = data.LoadAll ================================================ FILE: ch08/acme/internal/modules/list/list_test.go ================================================ package list import ( "context" "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // monkey patch calls to the data package defer func(original func(_ context.Context, _ data.Config) ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context, _ data.Config) ([]*data.Person, error) { result := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } var resultErr error return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) } func TestLister_Do_noResults(t *testing.T) { // monkey patch calls to the data package defer func(original func(_ context.Context, _ data.Config) ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = func(_ context.Context, _ data.Config) ([]*data.Person, error) { var result []*data.Person resultErr := data.ErrNotFound return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) } func TestLister_Do_error(t *testing.T) { // monkey patch calls to the data package defer func(original func(_ context.Context, _ data.Config) ([]*data.Person, error)) { // restore original loader = original }(loader) // replace method loader = 
func(_ context.Context, _ data.Config) ([]*data.Person, error) { var result []*data.Person resultErr := errors.New("failed to load people") return result, resultErr } // end of monkey patch // call method lister := &Lister{} persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) } ================================================ FILE: ch08/acme/internal/modules/register/register.go ================================================ package register import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // NewRegisterer creates and initializes a Registerer func NewRegisterer(cfg Config, exchanger Exchanger) *Registerer { return &Registerer{ cfg: cfg, exchanger: exchanger, } } // Exchanger will convert from one currency to another type Exchanger interface { // Exchange will perform the conversion Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) } // Config is the configuration for the Registerer type Config interface { Logger() logging.Logger RegistrationBasePrice() float64 DataDSN() string } // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. 
// It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. type Registerer struct { cfg Config exchanger Exchanger } // Do is API for this struct func (r *Registerer) Do(ctx context.Context, in *data.Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { r.logger().Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(ctx, in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(ctx, in, price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *data.Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) { price, err := r.exchanger.Exchange(ctx, r.cfg.RegistrationBasePrice(), currency) if err != nil { r.logger().Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(ctx context.Context, in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return saver(ctx, r.cfg, person) } func (r *Registerer) logger() logging.Logger { return r.cfg.Logger() } // this function as a variable allows us to Monkey Patch during testing var saver = data.Save ================================================ FILE: ch08/acme/internal/modules/register/register_test.go ================================================ package register import ( "context" "errors" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patch calls to the data package defer func(original func(_ context.Context, _ data.Config, _ *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(_ context.Context, _ data.Config, _ *data.Person) (int, error) { result := 888 var resultErr error return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, } ID, err := registerer.Do(ctx, in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) } func TestRegisterer_Do_error(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patch 
calls to the data package defer func(original func(_ context.Context, _ data.Config, _ *data.Person) (int, error)) { // restore original saver = original }(saver) // replace method saver = func(_ context.Context, _ data.Config, _ *data.Person) (int, error) { var result int resultErr := errors.New("failed to save") return result, resultErr } // end of monkey patch // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, } ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) } // Stub implementation of Config type testConfig struct{} // Logger implement Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // RegistrationBasePrice implement Config func (t *testConfig) RegistrationBasePrice() float64 { return 12.34 } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } type stubExchanger struct{} // Exchange implements Exchanger func (s stubExchanger) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { return 12.34, nil } ================================================ FILE: ch08/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration //go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated type GetModel interface { Do(ID int) (*data.Person, error) } // GetConfig is the config for the Get Handler type 
GetConfig interface { Logger() logging.Logger } // NewGetHandler is the constructor for GetHandler func NewGetHandler(cfg GetConfig, model GetModel) *GetHandler { return &GetHandler{ cfg: cfg, getter: model, } } // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. type GetHandler struct { cfg GetConfig getter GetModel } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") h.cfg.Logger().Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) h.cfg.Logger().Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch08/acme/internal/rest/get_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockGetModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { output := &data.Person{ ID: 1, FullName: "John", Phone: "0123456789", Currency: "USD", Price: 100, } mockGetModel := &MockGetModel{} mockGetModel.On("Do", mock.Anything).Return(output, nil).Once() return mockGetModel }, expectedStatus: http.StatusOK, expectedPayload: 
`{"id":1,"name":"John","phone":"0123456789","currency":"USD","price":100}` + "\n", }, { desc: "bad input (ID is invalid)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/x/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "x"}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "bad input (ID is missing)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person//", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "dependency fail", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "requested registration does not exist", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("person not found")).Once() return mockRegisterModel }, expectedStatus: 
http.StatusNotFound, expectedPayload: ``, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockGetModel := scenario.inModelMock() // build handler handler := NewGetHandler(&testConfig{}, mockGetModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } type testConfig struct { } func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } func (*testConfig) BindAddress() string { return "0.0.0.0:0" } ================================================ FILE: ch08/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" ) // ListModel will load all registrations //go:generate mockery -name=ListModel -case underscore -testonly -inpkg -note @generated type ListModel interface { Do() ([]*data.Person, error) } // NewLister is the constructor for ListHandler func NewListHandler(model ListModel) *ListHandler { return &ListHandler{ lister: model, } } // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { lister ListModel } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll people, err := h.lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it 
does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*data.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch08/acme/internal/rest/list_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockListModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { output := []*data.Person{ { ID: 1, FullName: "John", Phone: "0123456789", }, { ID: 2, FullName: "Paul", Phone: "0123456781", }, { ID: 3, FullName: "George", Phone: "0123456782", }, { ID: 1, FullName: "Ringo", Phone: "0123456783", }, } mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, 
expectedPayload: `{"people":[{"id":1,"name":"John","phone":"0123456789"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":1,"name":"Ringo","phone":"0123456783"}]}` + "\n", }, { desc: "dependency failure", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockListModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "no data", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { // no data output := []*data.Person{} mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[]}` + "\n", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockListModel := scenario.inModelMock() // build handler handler := NewListHandler(mockListModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch08/acme/internal/rest/mock_get_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockGetModel is an autogenerated mock type for the GetModel 
type type MockGetModel struct { mock.Mock } // Do provides a mock function with given fields: ID func (_m *MockGetModel) Do(ID int) (*data.Person, error) { ret := _m.Called(ID) var r0 *data.Person if rf, ok := ret.Get(0).(func(int) *data.Person); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(int) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch08/acme/internal/rest/mock_list_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockListModel is an autogenerated mock type for the ListModel type type MockListModel struct { mock.Mock } // Do provides a mock function with given fields: func (_m *MockListModel) Do() ([]*data.Person, error) { ret := _m.Called() var r0 []*data.Person if rf, ok := ret.Get(0).(func() []*data.Person); ok { r0 = rf() } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func() error); ok { r1 = rf() } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch08/acme/internal/rest/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: ctx, in func (_m *MockRegisterModel) Do(ctx context.Context, in *data.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok := 
ret.Get(0).(func(context.Context, *data.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch08/acme/internal/rest/not_found.go ================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch08/acme/internal/rest/not_found_test.go ================================================ package rest import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" ) func TestNotFoundHandler_ServeHTTP(t *testing.T) { // build inputs response := httptest.NewRecorder() request := &http.Request{} // call handler notFoundHandler(response, request) // validate outputs require.Equal(t, http.StatusNotFound, response.Code) } ================================================ FILE: ch08/acme/internal/rest/register.go ================================================ package rest import ( "context" "encoding/json" "fmt" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/data" ) // RegisterModel will validate and save a registration //go:generate mockery -name=RegisterModel -case underscore -testonly -inpkg -note @generated type RegisterModel interface { Do(ctx context.Context, in *data.Person) (int, error) } // NewRegisterHandler is the constructor for RegisterHandler func NewRegisterHandler(model RegisterModel) *RegisterHandler { return &RegisterHandler{ registerer: model, } } // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 
400. // There are some programmer errors possible but hopefully these will be caught in testing. type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // set latency budget for this API subCtx, cancel := context.WithTimeout(request.Context(), 1500*time.Millisecond) defer cancel() // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // call the business logic using the request data and context id, err := h.register(subCtx, requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(ctx context.Context, requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } return h.registerer.Do(ctx, person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch08/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "encoding/json" "errors" "io" "net/http" 
"net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ { desc: "Happy Path", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // valid downstream configuration resultID := 1234 var resultErr error mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(resultID, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusCreated, expectedHeader: "/person/1234/", }, { desc: "Bad Input / User Error", inRequest: func() *http.Request { invalidRequest := bytes.NewBufferString(`this is not valid JSON`) request, err := http.NewRequest("POST", "/person/register", invalidRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // Dependency should not be called mockRegisterModel := &MockRegisterModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, { desc: "Dependency Failure", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // call to the dependency failed resultErr := errors.New("something failed") mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(0, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t 
*testing.T) { // define model layer mock mockRegisterModel := scenario.inModelMock() // build handler handler := NewRegisterHandler(mockRegisterModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRegisterRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch08/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/logging" "github.com/gorilla/mux" ) // Config is the config for the REST package type Config interface { Logger() logging.Logger BindAddress() string } // New will create and initialize the server func New(cfg Config, getModel GetModel, listModel ListModel, registerModel RegisterModel) *Server { return &Server{ address: cfg.BindAddress(), handlerGet: NewGetHandler(cfg, getModel), handlerList: NewListHandler(listModel), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(registerModel), } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, Addr: s.address, } // listen 
for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch08/acme/main.go ================================================ package main import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/acme/internal/rest" ) func main() { // bind stop channel to context ctx := context.Background() // build the exchanger exchanger := exchange.NewConverter(config.App) // build model layer getModel := get.NewGetter(config.App) listModel := list.NewLister(config.App) registerModel := register.NewRegisterer(config.App, exchanger) // start REST server server := rest.New(config.App, getModel, listModel, registerModel) server.Listen(ctx.Done()) } ================================================ FILE: ch08/fake.go ================================================ package ch08 func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } 
================================================ FILE: ch09/01_jit_injection/01_injecting_db.go ================================================ package jit_injection func NewMyLoadPersonLogic(ds DataSource) *MyLoadPersonLogic { return &MyLoadPersonLogic{ dataSource: ds, } } type MyLoadPersonLogic struct { dataSource DataSource } // Load person by supplied ID func (m *MyLoadPersonLogic) Load(ID int) (Person, error) { return m.dataSource.Load(ID) } type DataSource interface { // Load person by ID Load(ID int) (Person, error) } type Person struct { Name string } ================================================ FILE: ch09/01_jit_injection/01_injecting_db_test.go ================================================ package jit_injection import ( "testing" "github.com/stretchr/testify/assert" ) func TestMyLoadPersonLogic(t *testing.T) { // setup the mock db mockDB := &mockDB{ out: Person{Name: "Fred"}, } // call the object we are testing testObj := NewMyLoadPersonLogic(mockDB) result, resultErr := testObj.Load(123) // validate expectations assert.Equal(t, Person{Name: "Fred"}, result) assert.Nil(t, resultErr) } // mock implementation of DataSource type mockDB struct { out Person outErr error } // Load implements DataSource func (m *mockDB) Load(ID int) (Person, error) { return m.out, m.outErr } ================================================ FILE: ch09/01_jit_injection/02_injecting_business_logic.go ================================================ package jit_injection import ( "errors" "net/http" ) func NewLoadPersonHandler(logic LoadPersonLogic) *LoadPersonHandler { return &LoadPersonHandler{ businessLogic: logic, } } type LoadPersonHandler struct { businessLogic LoadPersonLogic } func (h *LoadPersonHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { requestedID, err := h.extractInputFromRequest(request) output, err := h.businessLogic.Load(requestedID) if err != nil { response.WriteHeader(http.StatusInternalServerError) return } 
h.writeOutput(response, output) } // extract the person ID from the request func (h *LoadPersonHandler) extractInputFromRequest(request *http.Request) (int, error) { return 0, errors.New("not implemented yet") } // convert person to JSON and write to the HTTP response func (h *LoadPersonHandler) writeOutput(writer http.ResponseWriter, person Person) { // not implemented yet } type LoadPersonLogic interface { // Load person by supplied ID Load(ID int) (Person, error) } ================================================ FILE: ch09/01_jit_injection/03_injecting_db_jit.go ================================================ package jit_injection import ( "errors" ) type MyLoadPersonLogicJIT struct { dataSource DataSourceJIT } // Load person by supplied ID func (m *MyLoadPersonLogicJIT) Load(ID int) (Person, error) { return m.getDataSource().Load(ID) } func (m *MyLoadPersonLogicJIT) getDataSource() DataSourceJIT { if m.dataSource == nil { m.dataSource = NewMyDataSourceJIT() } return m.dataSource } type DataSourceJIT interface { // Load person by ID Load(ID int) (Person, error) } func NewMyDataSourceJIT() *MyDataSourceJIT { return &MyDataSourceJIT{} } // Default implementation of DataSourceJIT type MyDataSourceJIT struct { } func (m *MyDataSourceJIT) Load(ID int) (Person, error) { return Person{}, errors.New("not implemented yet") } ================================================ FILE: ch09/01_jit_injection/03_injecting_db_jit_test.go ================================================ package jit_injection import ( "testing" "github.com/stretchr/testify/assert" ) func TestMyLoadPersonLogicJIT(t *testing.T) { // setup the mock db mockDB := &mockDB{ out: Person{Name: "Fred"}, } // call the object we are testing testObj := MyLoadPersonLogicJIT{ dataSource: mockDB, } result, resultErr := testObj.Load(123) // validate expectations assert.Equal(t, Person{Name: "Fred"}, result) assert.Nil(t, resultErr) } ================================================ FILE: 
ch09/01_jit_injection/04_noop_debugger.go ================================================ package jit_injection import ( "errors" ) type ObjectWithDebugger struct { Debugger Debugger } func (o *ObjectWithDebugger) DoSomethingAmazing(input int) error { o.getDebugger().Log("input was: %d", input) err := o.doSomething() o.getDebugger().Log("result was: %v", err) return err } func (o *ObjectWithDebugger) getDebugger() Debugger { if o.Debugger == nil { o.Debugger = &noopDebugger{} } return o.Debugger } func (o *ObjectWithDebugger) doSomething() error { return errors.New("not implemented yet") } type Debugger interface { Log(msg string, args ...interface{}) } // NO-OP implementation of the Debugger interface type noopDebugger struct { // intentionally left blank } // Log implements Debugger func (n *noopDebugger) Log(_ string, args ...interface{}) { // intentionally does nothing } ================================================ FILE: ch09/02_advantages/01_long_constructor.go ================================================ package advantages import ( "io" ) func NewGenerator(storage Storage, renderer Renderer, template io.Reader) *Generator { return &Generator{ storage: storage, renderer: renderer, template: template, } } type Generator struct { storage Storage renderer Renderer template io.Reader } func (g *Generator) Generate(destination io.Writer, params ...interface{}) { } type Storage interface { Load() []interface{} } type Renderer interface { Render(template io.Reader, params ...interface{}) []byte } ================================================ FILE: ch09/02_advantages/02_short_constructor.go ================================================ package advantages import ( "io" ) func NewGeneratorV2(template io.Reader) *Generator { return &Generator{ template: template, } } func (g *Generator) getStorage() Storage { if g.storage == nil { g.storage = &DefaultStorage{} } return g.storage } func (g *Generator) getRenderer() Renderer { if g.renderer == nil { 
g.renderer = &DefaultRenderer{} } return g.renderer } // Default implementation of Storage type DefaultStorage struct{} // Load implements Storage func (d *DefaultStorage) Load() []interface{} { return nil } // Default implementation of Storage type DefaultRenderer struct{} // Load implements Renderer func (d *DefaultRenderer) Render(template io.Reader, params ...interface{}) []byte { return nil } ================================================ FILE: ch09/02_advantages/03_optional_dep_without_jitdi.go ================================================ package advantages func NewLoaderWithoutJIT(ds Datastore) *LoaderWithoutJIT { return &LoaderWithoutJIT{ datastore: ds, } } type LoaderWithoutJIT struct { // required private dependency datastore Datastore // optional cache OptionalCache Cache } func (l *LoaderWithoutJIT) Load(ID int) (*Animal, error) { var output *Animal var err error // attempt to load from cache if l.OptionalCache != nil { output = l.OptionalCache.Get(ID) if output != nil { // return cached value return output, nil } } // load from data store output, err = l.datastore.Load(ID) if err != nil { return nil, err } // cache the loaded value if l.OptionalCache != nil { l.OptionalCache.Put(ID, output) } // output the result return output, nil } type Cache interface { Get(ID int) *Animal Put(ID int, value *Animal) } type Datastore interface { Load(ID int) (*Animal, error) Save(ID int, value *Animal) error } type Animal struct { // some data fields go here } ================================================ FILE: ch09/02_advantages/04_optional_dep_with_jitdi.go ================================================ package advantages func NewLoaderWithJIT(ds Datastore) *LoaderWithJIT { return &LoaderWithJIT{ datastore: ds, } } type LoaderWithJIT struct { // required private dependency datastore Datastore // optional cache OptionalCache Cache } func (l *LoaderWithJIT) Load(ID int) (*Animal, error) { // attempt to load from cache output := l.cache().Get(ID) if output 
!= nil { // return cached value return output, nil } // load from data store output, err := l.datastore.Load(ID) if err != nil { return nil, err } // cache the loaded value l.cache().Put(ID, output) // output the result return output, nil } func (l *LoaderWithJIT) cache() Cache { if l.OptionalCache == nil { l.OptionalCache = &noopCache{} } return l.OptionalCache } // NO-OP implementation of the cache type noopCache struct { // intentionally blank } func (n *noopCache) Get(ID int) *Animal { // intentionally does nothing return nil } func (n *noopCache) Put(ID int, value *Animal) { // intentionally does nothing } ================================================ FILE: ch09/02_advantages/05_loader.go ================================================ package advantages import ( "errors" ) func NewLoader(ds Datastore, cache Cache) *MyLoader { return &MyLoader{ ds: ds, cache: cache, } } type MyLoader struct { ds Datastore cache Cache } func (l *MyLoader) LoadAll() ([]interface{}, error) { return nil, errors.New("not implemented") } ================================================ FILE: ch09/02_advantages/06_global_variable/06_global_variable.go ================================================ package global_variable // Global singleton of connections to our data store var storage UserStorage type Saver struct { } func (s *Saver) Do(in *User) error { err := s.validate(in) if err != nil { return err } return storage.Save(in) } func (s *Saver) validate(in *User) error { // validate user and return error when there is a problem return nil } type UserStorage interface { Save(in *User) error } type User struct { Name string Password string } ================================================ FILE: ch09/02_advantages/07_global_variable_jit/07_global_variable_jit.go ================================================ package global_variable_jit // Global singleton of connections to our data store var storage UserStorage type Saver struct { storage UserStorage } func (s *Saver) Do(in 
*User) error { err := s.validate(in) if err != nil { return err } return s.getStorage().Save(in) } // Just-in-time DI func (s *Saver) getStorage() UserStorage { if s.storage == nil { s.storage = storage } return s.storage } func (s *Saver) validate(in *User) error { // validate user and return error when there is a problem return nil } type UserStorage interface { Save(in *User) error } type User struct { Name string Password string } ================================================ FILE: ch09/02_advantages/07_global_variable_jit/07_global_variable_jit_test.go ================================================ package global_variable_jit import ( "testing" "github.com/stretchr/testify/assert" ) func TestSaver_Do(t *testing.T) { // input carol := &User{ Name: "Carol", Password: "IamKing", } // mocks/stubs stubStorage := &StubUserStorage{} // do call saver := &Saver{ storage: stubStorage, } resultErr := saver.Do(carol) // validate assert.NotEqual(t, resultErr, "unexpected error") } // Stub implementation of UserStorage type StubUserStorage struct{} func (s *StubUserStorage) Save(_ *User) error { // return "happy path" return nil } ================================================ FILE: ch09/02_advantages/08_car_v1.go ================================================ package advantages type CarV1 struct { engine Engine } func (c *CarV1) Drive() { c.engine.Start() defer c.engine.Stop() c.engine.Drive() } type Engine interface { Start() Drive() Stop() } ================================================ FILE: ch09/02_advantages/09_car_v2.go ================================================ package advantages type CarV2 struct { engine Engine } func (c *CarV2) Drive() { engine := c.getEngine() engine.Start() defer engine.Stop() engine.Drive() } func (c *CarV2) getEngine() Engine { if c.engine == nil { c.engine = newEngine() } return c.engine } func newEngine() Engine { // not implemented return nil } ================================================ FILE: 
ch09/03_applying/01_commands.sh
================================================
#!/usr/bin/env bash

# Report per-package test coverage for the acme service (02_coverage.txt appears to be sample output).
# NOTE(review): this changes into ch08/ rather than ch09/ -- presumably intentional (measuring the
# starting point for this chapter); confirm. $GOPATH is unquoted and a failed cd is not checked.
cd $GOPATH/src/github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch08/

package-coverage -a -prefix $(go list)/ ./acme/

================================================
FILE: ch09/03_applying/02_coverage.txt
================================================
-------------------------------------------------------------------------
| Branch | Dir | |
| Cov% | Stmts | Cov% | Stmts | Package |
-------------------------------------------------------------------------
| 65.66 | 265 | 0.00 | 7 | acme/ |
| 47.83 | 23 | 47.83 | 23 | acme/internal/config/ |
| 0.00 | 4 | 0.00 | 4 | acme/internal/logging/ |
| 73.77 | 61 | 73.77 | 61 | acme/internal/modules/data/ |
| 61.70 | 47 | 61.70 | 47 | acme/internal/modules/exchange/ |
| 85.71 | 7 | 85.71 | 7 | acme/internal/modules/get/ |
| 46.15 | 13 | 46.15 | 13 | acme/internal/modules/list/ |
| 62.07 | 29 | 62.07 | 29 | acme/internal/modules/register/ |
| 79.73 | 74 | 79.73 | 74 | acme/internal/rest/ |
-------------------------------------------------------------------------

================================================
FILE: ch09/03_applying/03_initial_dao.go
================================================
//+build willNotCompile

package applying

// NOTE(review): the willNotCompile build tag keeps this listing out of normal builds; it refers to
// getDB, Load, Config and Person, none of which are declared in this file.
import (
	"context"
)

// NewDAO will initialize the database connection pool (if not already done) and return a data access object which
// can be used to interact with the database
func NewDAO(cfg Config) *DAO {
	// initialize the db connection pool (both results of getDB are deliberately discarded here)
	_, _ = getDB(cfg)

	return &DAO{
		cfg: cfg,
	}
}

// DAO is a data access object that provides an abstraction over our database interactions.
type DAO struct {
	// cfg is passed through to the package-level Load on every call
	cfg Config
}

// Load will attempt to load and return a person.
// It will return ErrNotFound when the requested person does not exist.
// Any other errors returned are caused by the underlying database or our connection to it.
func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) { return Load(ctx, d.cfg, ID) } ================================================ FILE: ch09/04_disadvantages/01_uncertain_init_state.go ================================================ package disadvantages import ( "context" "errors" "net" "sync" ) type ConnectionPool interface { IsReady() <-chan struct{} Get() net.Conn Release(conn net.Conn) } type Sender struct { connectionPool ConnectionPool initPoolOnce sync.Once } func (l *Sender) Send(ctx context.Context, payload []byte) error { pool := l.getConnectionPool() // ensure pool is ready select { case <-pool.IsReady(): // happy path case <-ctx.Done(): // context timed out or was cancelled return errors.New("failed to get connection") } // get connection from pool and return afterwards conn := pool.Get() defer l.connectionPool.Release(conn) // send and return _, err := conn.Write(payload) return err } func (l *Sender) getConnectionPool() ConnectionPool { // Inject the connection pool with JIT DI if l.connectionPool == nil { myPool := &myConnectionPool{} go myPool.init() l.connectionPool = myPool } return l.connectionPool } // default implementation of the connection pool type myConnectionPool struct { } // IsReady implements ConnectionPool func (m *myConnectionPool) IsReady() <-chan struct{} { // not implemented yet return make(chan struct{}) } // IsReady implements ConnectionPool func (m *myConnectionPool) Get() net.Conn { // not implemented yet return nil } // IsReady implements ConnectionPool func (m *myConnectionPool) Release(_ net.Conn) { // not implemented yet } func (m *myConnectionPool) init() { // create connection and populate the pool } ================================================ FILE: ch09/04_disadvantages/02_certain_init_state.go ================================================ package disadvantages func (l *Sender) SendWithoutReadyCheck(payload []byte) error { pool := l.getConnectionPool() // get connection from pool and return 
afterwards conn := pool.Get() defer l.connectionPool.Release(conn) // send and return _, err := conn.Write(payload) return err } ================================================ FILE: ch09/04_disadvantages/03_cpool_slow_constructor.go ================================================ package disadvantages func newConnectionPool() ConnectionPool { pool := &myConnectionPool{} // initialize the pool pool.init() // return a "ready to use pool" return pool } ================================================ FILE: ch09/04_disadvantages/04_get_pool_with_once.go ================================================ package disadvantages func (l *Sender) getConnectionPoolOnce() ConnectionPool { l.initPoolOnce.Do(func() { l.connectionPool = newConnectionPool() }) return l.connectionPool } ================================================ FILE: ch09/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // App is the application config var App *Config // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string // environmental dependencies logger logging.Logger } // Logger returns a reference to the singleton logger func (c *Config) Logger() logging.Logger { if c.logger == nil { c.logger = &logging.LoggerStdOut{} } return 
c.logger } // RegistrationBasePrice returns the base price for registrations func (c *Config) RegistrationBasePrice() float64 { return c.BasePrice } // DataDSN returns the DSN func (c *Config) DataDSN() string { return c.DSN } // ExchangeBaseURL returns the Base URL from which we can load exchange rates func (c *Config) ExchangeBaseURL() string { return c.ExchangeRateBaseURL } // ExchangeAPIKey returns the DSN func (c *Config) ExchangeAPIKey() string { return c.ExchangeRateAPIKey } // BindAddress returns the host and port this service should bind to func (c *Config) BindAddress() string { return c.Address } // Load returns the config loaded from environment func init() { filename, found := os.LookupEnv(DefaultEnvVar) if !found { logging.L.Error("failed to locate file specified by %s", DefaultEnvVar) return } _ = load(filename) } func load(filename string) error { App = &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { logging.L.Error("failed to read config file. err: %s", err) return err } err = json.Unmarshal(bytes, App) if err != nil { logging.L.Error("failed to parse config file. 
err : %s", err) return err } return nil } ================================================ FILE: ch09/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: &Config{}, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, App, scenario.desc) }) } } ================================================ FILE: ch09/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // Logger is our standard interface type Logger interface { Debug(message string, args ...interface{}) Info(message string, args ...interface{}) Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } // L is the global instance of the logger var L = &LoggerStdOut{} // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) 
} // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) } ================================================ FILE: ch09/acme/internal/modules/data/dao.go ================================================ package data import ( "context" "database/sql" "time" ) // NewDAO will initialize the database connection pool (if not already done) and return a data access object which // can be used to interact with the database func NewDAO(cfg Config) *DAO { // initialize the db connection pool _, _ = getDB(cfg) return &DAO{ cfg: cfg, } } // DAO is a data access object that provides an abstraction over our database interactions. type DAO struct { cfg Config // Tracker is an optional query timer Tracker QueryTracker } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) { // track processing time defer d.getTracker().Track("Load", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select row := db.QueryRowContext(subCtx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { d.cfg.Logger().Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } d.cfg.Logger().Error("failed to convert query result. 
err: %s", err) return nil, err } return out, nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) LoadAll(ctx context.Context) ([]*Person, error) { // track processing time defer d.getTracker().Track("LoadAll", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select rows, err := db.QueryContext(subCtx, sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { d.cfg.Logger().Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { d.cfg.Logger().Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func (d *DAO) Save(ctx context.Context, in *Person) (int, error) { // track processing time defer d.getTracker().Track("Save", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB insert result, err := db.ExecContext(subCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { d.cfg.Logger().Error("failed to save person into DB. 
err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { d.cfg.Logger().Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } func (d *DAO) getTracker() QueryTracker { if d.Tracker == nil { d.Tracker = &noopTracker{} } return d.Tracker } ================================================ FILE: ch09/acme/internal/modules/data/data.go ================================================ package data import ( "database/sql" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) // Config is the configuration for the data package type Config interface { // Logger returns a reference to the logger Logger() logging.Logger // DataDSN returns the data source name DataDSN() string } var getDB = func(cfg Config) (*sql.DB, error) { if db == nil { var err error db, err = sql.Open("mysql", cfg.DataDSN()) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch09/acme/internal/modules/data/data_test.go ================================================ package data import ( "context" "database/sql" "errors" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() 
require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patching starts here defer func(original func(_ Config) (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func(_ Config) (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, 
in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) results, err := dao.LoadAll(ctx) // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := 
convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) result, err := dao.Load(ctx, 2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } type testConfig struct{} // Logger implements Config func (t *testConfig) Logger() logging.Logger { return logging.LoggerStdOut{} } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } ================================================ FILE: ch09/acme/internal/modules/data/tracker.go ================================================ package data import ( "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" ) // QueryTracker is an interface to track query timing type QueryTracker interface { // Track will record/out the time a query took by calculating time.Now().Sub(start) Track(key string, start time.Time) } // NO-OP implementation of QueryTracker type 
noopTracker struct{} // Track implements QueryTracker func (_ *noopTracker) Track(_ string, _ time.Time) { // intentionally does nothing } // NewLogTracker returns a Tracker that outputs tracking data to log func NewLogTracker(logger logging.Logger) *LogTracker { return &LogTracker{ logger: logger, } } // LogTracker implements QueryTracker and outputs to the supplied logger type LogTracker struct { logger logging.Logger } // Track implements QueryTracker func (l *LogTracker) Track(key string, start time.Time) { l.logger.Info("[%s] Timing: %s\n", key, time.Now().Sub(start).String()) } ================================================ FILE: ch09/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // NewConverter creates and initializes the converter func NewConverter(cfg Config) *Converter { return &Converter{ cfg: cfg, } } // Config is the config for Converter type Config interface { Logger() logging.Logger ExchangeBaseURL() string ExchangeAPIKey() string } // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct { cfg Config } // Exchange will perform the conversion func (c *Converter) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { // load rate from the external API response, err := c.loadRateFromServer(ctx, currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 
decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, c.cfg.ExchangeBaseURL(), c.cfg.ExchangeAPIKey(), currency) // perform request req, err := http.NewRequest("GET", url, nil) if err != nil { c.logger().Warn("[exchange] failed to create request. err: %s", err) return nil, err } // set latency budget for the upstream call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // replace the default context with our custom one req = req.WithContext(subCtx) // perform the HTTP request response, err := http.DefaultClient.Do(req) if err != nil { c.logger().Warn("[exchange] failed to load. err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) c.logger().Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) c.logger().Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { c.logger().Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { c.logger().Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } func (c *Converter) logger() logging.Logger { return c.cfg.Logger() } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch09/acme/internal/modules/exchange/converter_ext_bounday_test.go ================================================ // +build external package exchange import ( "context" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestExternalBoundaryTest(t *testing.T) { // define the config cfg := &testConfig{ baseURL: config.App.ExchangeRateBaseURL, apiKey: config.App.ExchangeRateAPIKey, } // create a converter to test converter := NewConverter(cfg) // fetch from the server response, err := converter.loadRateFromServer(context.Background(), "AUD") require.NotNil(t, response) require.NoError(t, err) // parse the response resultRate, err := converter.extractRate(response, "AUD") require.NoError(t, err) // validate the result assert.True(t, resultRate > 0) } ================================================ FILE: ch09/acme/internal/modules/exchange/converter_int_bounday_test.go ================================================ package exchange import ( "context" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/stretchr/testify/assert" ) func TestInternalBoundaryTest(t *testing.T) { // start our test server server := httptest.NewServer(&happyExchangeRateService{}) defer server.Close() // define the config cfg := &testConfig{ baseURL: server.URL, apiKey: "", } // create a converter to test converter := NewConverter(cfg) resultRate, resultErr := converter.Exchange(context.Background(), 100.00, "AUD") // validate the result assert.Equal(t, 101.01, 
resultRate) assert.NoError(t, resultErr) } type happyExchangeRateService struct{} // ServeHTTP implements http.Handler func (*happyExchangeRateService) ServeHTTP(response http.ResponseWriter, request *http.Request) { payload := []byte(` { "success":true, "historical":true, "date":"2010-11-09", "timestamp":1289347199, "source":"USD", "quotes":{ "USDAUD":0.989981 } }`) response.Write(payload) } // test implementation of Config type testConfig struct { baseURL string apiKey string } // Logger implements Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // ExchangeBaseURL implements Config func (t *testConfig) ExchangeBaseURL() string { return t.baseURL } // ExchangeAPIKey implements Config func (t *testConfig) ExchangeAPIKey() string { return t.apiKey } ================================================ FILE: ch09/acme/internal/modules/get/get.go ================================================ package get import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // NewGetter creates and initializes a Getter func NewGetter(cfg Config) *Getter { return &Getter{ cfg: cfg, } } // Config is the configuration for Getter type Config interface { Logger() logging.Logger DataDSN() string } // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { cfg Config data myLoader } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := g.getLoader().Load(context.TODO(), ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. 
return nil, errPersonNotFound } return nil, err } return person, err } func (g *Getter) getLoader() myLoader { if g.data == nil { g.data = data.NewDAO(g.cfg) } return g.data } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { Load(ctx context.Context, ID int) (*data.Person, error) } ================================================ FILE: ch09/acme/internal/modules/get/go_test.go ================================================ package get import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockResult := &data.Person{ ID: 1234, FullName: "Doug", } mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(mockResult, nil).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, data.ErrNotFound).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, errors.New("something failed")).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate 
expectations require.Error(t, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: ch09/acme/internal/modules/get/mock_my_loader_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package get import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMyLoader is an autogenerated mock type for the myLoader type type mockMyLoader struct { mock.Mock } // Load provides a mock function with given fields: ctx, ID func (_m *mockMyLoader) Load(ctx context.Context, ID int) (*data.Person, error) { ret := _m.Called(ctx, ID) var r0 *data.Person if rf, ok := ret.Get(0).(func(context.Context, int) *data.Person); ok { r0 = rf(ctx, ID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { r1 = rf(ctx, ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch09/acme/internal/modules/list/list.go ================================================ package list import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // NewLister creates and initializes a Lister func NewLister(cfg Config) *Lister { return &Lister{ cfg: cfg, } } // Config is the config for Lister type Config interface { Logger() logging.Logger DataDSN() string } // Lister will attempt to load all people in the database. 
// It can return an error caused by the data layer type Lister struct { cfg Config data myLoader } // Exchange will load the people from the data layer func (l *Lister) Do() ([]*data.Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return people, nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := l.getLoader().LoadAll(context.TODO()) if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. return nil, errPeopleNotFound } return nil, err } return people, nil } func (l *Lister) getLoader() myLoader { if l.data == nil { l.data = data.NewDAO(l.cfg) // temporarily add a log tracker l.data.(*data.DAO).Tracker = data.NewLogTracker(l.cfg.Logger()) } return l.data } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { LoadAll(ctx context.Context) ([]*data.Person, error) } ================================================ FILE: ch09/acme/internal/modules/list/list_test.go ================================================ package list import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // configure the mock loader mockResult := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(mockResult, nil).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func 
TestLister_Do_noResults(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, data.ErrNotFound).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func TestLister_Do_error(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, errors.New("something failed")).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: ch09/acme/internal/modules/list/mock_my_loader_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package list import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMyLoader is an autogenerated mock type for the myLoader type type mockMyLoader struct { mock.Mock } // LoadAll provides a mock function with given fields: ctx func (_m *mockMyLoader) LoadAll(ctx context.Context) ([]*data.Person, error) { ret := _m.Called(ctx) var r0 []*data.Person if rf, ok := ret.Get(0).(func(context.Context) []*data.Person); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context) error); ok { r1 = rf(ctx) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch09/acme/internal/modules/register/mock_my_saver_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package register import ( 
"context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMySaver is an autogenerated mock type for the mySaver type type mockMySaver struct { mock.Mock } // Save provides a mock function with given fields: ctx, in func (_m *mockMySaver) Save(ctx context.Context, in *data.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok := ret.Get(0).(func(context.Context, *data.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch09/acme/internal/modules/register/register.go ================================================ package register import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // NewRegisterer creates and initializes a Registerer func NewRegisterer(cfg Config, exchanger Exchanger) *Registerer { return &Registerer{ cfg: cfg, exchanger: exchanger, } } // Exchanger will convert from one currency to another type Exchanger interface { // Exchange will perform the conversion Exchange(ctx context.Context, basePrice float64, 
currency string) (float64, error) } // Config is the configuration for the Registerer type Config interface { Logger() logging.Logger RegistrationBasePrice() float64 DataDSN() string } // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. type Registerer struct { cfg Config exchanger Exchanger data mySaver } // Do is API for this struct func (r *Registerer) Do(ctx context.Context, in *data.Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { r.logger().Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(ctx, in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(ctx, in, price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *data.Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) { price, err := r.exchanger.Exchange(ctx, r.cfg.RegistrationBasePrice(), currency) if err != nil { r.logger().Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(ctx context.Context, in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return r.getSaver().Save(ctx, person) } func (r *Registerer) getSaver() mySaver { if r.data == nil { r.data = data.NewDAO(r.cfg) } return r.data } func (r *Registerer) logger() logging.Logger { return r.cfg.Logger() } //go:generate mockery -name=mySaver -case underscore -testonly -inpkg -note @generated type mySaver interface { Save(ctx context.Context, in *data.Person) (int, error) } ================================================ FILE: ch09/acme/internal/modules/register/register_test.go ================================================ package register import ( "context" "errors" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // configure the mock saver mockResult := 888 mockSaver := &mockMySaver{} mockSaver.On("Save", mock.Anything, mock.Anything).Return(mockResult, nil).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) assert.True(t, mockSaver.AssertExpectations(t)) } func TestRegisterer_Do_error(t *testing.T) { // configure the mock saver mockSaver := &mockMySaver{} 
mockSaver.On("Save", mock.Anything, mock.Anything).Return(0, errors.New("something failed")).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) assert.True(t, mockSaver.AssertExpectations(t)) } // Stub implementation of Config type testConfig struct{} // Logger implement Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // RegistrationBasePrice implement Config func (t *testConfig) RegistrationBasePrice() float64 { return 12.34 } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } type stubExchanger struct{} // Exchange implements Exchanger func (s stubExchanger) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { return 12.34, nil } ================================================ FILE: ch09/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration //go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated type GetModel interface { Do(ID int) (*data.Person, error) } // GetConfig is the config for the Get Handler type GetConfig interface { Logger() logging.Logger } // NewGetHandler is 
the constructor for GetHandler func NewGetHandler(cfg GetConfig, model GetModel) *GetHandler { return &GetHandler{ cfg: cfg, getter: model, } } // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. type GetHandler struct { cfg GetConfig getter GetModel } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") h.cfg.Logger().Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) h.cfg.Logger().Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch09/acme/internal/rest/get_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockGetModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { output := &data.Person{ ID: 1, FullName: "John", Phone: "0123456789", Currency: "USD", Price: 100, } mockGetModel := &MockGetModel{} mockGetModel.On("Do", mock.Anything).Return(output, nil).Once() return mockGetModel }, expectedStatus: http.StatusOK, expectedPayload: 
`{"id":1,"name":"John","phone":"0123456789","currency":"USD","price":100}` + "\n", }, { desc: "bad input (ID is invalid)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/x/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "x"}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "bad input (ID is missing)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person//", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "dependency fail", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "requested registration does not exist", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("person not found")).Once() return mockRegisterModel }, expectedStatus: 
http.StatusNotFound, expectedPayload: ``, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockGetModel := scenario.inModelMock() // build handler handler := NewGetHandler(&testConfig{}, mockGetModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } type testConfig struct { } func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } func (*testConfig) BindAddress() string { return "0.0.0.0:0" } ================================================ FILE: ch09/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" ) // ListModel will load all registrations //go:generate mockery -name=ListModel -case underscore -testonly -inpkg -note @generated type ListModel interface { Do() ([]*data.Person, error) } // NewLister is the constructor for ListHandler func NewListHandler(model ListModel) *ListHandler { return &ListHandler{ lister: model, } } // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { lister ListModel } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll people, err := h.lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it 
does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*data.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch09/acme/internal/rest/list_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockListModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { output := []*data.Person{ { ID: 1, FullName: "John", Phone: "0123456789", }, { ID: 2, FullName: "Paul", Phone: "0123456781", }, { ID: 3, FullName: "George", Phone: "0123456782", }, { ID: 1, FullName: "Ringo", Phone: "0123456783", }, } mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, 
expectedPayload: `{"people":[{"id":1,"name":"John","phone":"0123456789"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":1,"name":"Ringo","phone":"0123456783"}]}` + "\n", }, { desc: "dependency failure", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockListModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "no data", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { // no data output := []*data.Person{} mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[]}` + "\n", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockListModel := scenario.inModelMock() // build handler handler := NewListHandler(mockListModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch09/acme/internal/rest/mock_get_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockGetModel is an autogenerated mock type for the GetModel 
// type
type MockGetModel struct {
	mock.Mock
}

// Do provides a mock function with given fields: ID
//
// The call is recorded through _m.Called(ID). If a configured return value is
// itself a func(int) *data.Person / func(int) error, it is invoked with ID;
// otherwise the configured value is returned (a nil first value stays a nil
// *data.Person).
func (_m *MockGetModel) Do(ID int) (*data.Person, error) {
	ret := _m.Called(ID)

	var r0 *data.Person
	if rf, ok := ret.Get(0).(func(int) *data.Person); ok {
		r0 = rf(ID)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*data.Person)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(int) error); ok {
		r1 = rf(ID)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

================================================
FILE: ch09/acme/internal/rest/mock_list_model_test.go
================================================
// Code generated by mockery v1.0.0
// @generated

package rest

import (
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data"
	"github.com/stretchr/testify/mock"
)

// MockListModel is an autogenerated mock type for the ListModel type
type MockListModel struct {
	mock.Mock
}

// Do provides a mock function with given fields:
//
// A configured func() []*data.Person / func() error return value is called;
// otherwise the configured value is returned (nil stays a nil slice).
func (_m *MockListModel) Do() ([]*data.Person, error) {
	ret := _m.Called()

	var r0 []*data.Person
	if rf, ok := ret.Get(0).(func() []*data.Person); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).([]*data.Person)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func() error); ok {
		r1 = rf()
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

================================================
FILE: ch09/acme/internal/rest/mock_register_model_test.go
================================================
// Code generated by mockery v1.0.0
// @generated

package rest

import (
	"context"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data"
	"github.com/stretchr/testify/mock"
)

// MockRegisterModel is an autogenerated mock type for the RegisterModel type
type MockRegisterModel struct {
	mock.Mock
}

// Do provides a mock function with given fields: ctx, in
//
// NOTE: unlike the pointer/slice mocks above, the int result is type-asserted
// without a nil check, so the configured first return value must be an int
// (or a matching func) or this call panics.
func (_m *MockRegisterModel) Do(ctx context.Context, in *data.Person) (int, error) {
	ret := _m.Called(ctx, in)

	var r0 int
	if rf, ok := ret.Get(0).(func(context.Context, *data.Person) int); ok {
		r0 = rf(ctx, in)
	} else {
		r0 = ret.Get(0).(int)
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok {
		r1 = rf(ctx, in)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

================================================
FILE: ch09/acme/internal/rest/not_found.go
================================================
package rest

import (
	"net/http"
)

// notFoundHandler responds with 404 and the body "Not found"; a failure to
// write the body is deliberately ignored.
func notFoundHandler(response http.ResponseWriter, _ *http.Request) {
	response.WriteHeader(http.StatusNotFound)
	_, _ = response.Write([]byte(`Not found`))
}

================================================
FILE: ch09/acme/internal/rest/not_found_test.go
================================================
package rest

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestNotFoundHandler_ServeHTTP checks that the catch-all handler answers 404.
func TestNotFoundHandler_ServeHTTP(t *testing.T) {
	// build inputs
	response := httptest.NewRecorder()
	request := &http.Request{}

	// call handler
	notFoundHandler(response, request)

	// validate outputs
	require.Equal(t, http.StatusNotFound, response.Code)
}

================================================
FILE: ch09/acme/internal/rest/register.go
================================================
package rest

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/data"
)

// RegisterModel will validate and save a registration
//go:generate mockery -name=RegisterModel -case underscore -testonly -inpkg -note @generated
type RegisterModel interface {
	// Do presumably returns the new person's ID (the handler puts it in the
	// Location header) or an error.
	Do(ctx context.Context, in *data.Person) (int, error)
}

// NewRegisterHandler is the constructor for RegisterHandler
func NewRegisterHandler(model RegisterModel) *RegisterHandler {
	return &RegisterHandler{
		registerer: model,
	}
}

// RegisterHandler is the HTTP handler for the "Register" endpoint
// In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP
400. // There are some programmer errors possible but hopefully these will be caught in testing. type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // set latency budget for this API subCtx, cancel := context.WithTimeout(request.Context(), 1500*time.Millisecond) defer cancel() // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // call the business logic using the request data and context id, err := h.register(subCtx, requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(ctx context.Context, requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } return h.registerer.Do(ctx, person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch09/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "encoding/json" "errors" "io" "net/http" 
"net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ { desc: "Happy Path", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // valid downstream configuration resultID := 1234 var resultErr error mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(resultID, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusCreated, expectedHeader: "/person/1234/", }, { desc: "Bad Input / User Error", inRequest: func() *http.Request { invalidRequest := bytes.NewBufferString(`this is not valid JSON`) request, err := http.NewRequest("POST", "/person/register", invalidRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // Dependency should not be called mockRegisterModel := &MockRegisterModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, { desc: "Dependency Failure", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // call to the dependency failed resultErr := errors.New("something failed") mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(0, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t 
*testing.T) { // define model layer mock mockRegisterModel := scenario.inModelMock() // build handler handler := NewRegisterHandler(mockRegisterModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRegisterRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch09/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/logging" "github.com/gorilla/mux" ) // Config is the config for the REST package type Config interface { Logger() logging.Logger BindAddress() string } // New will create and initialize the server func New(cfg Config, getModel GetModel, listModel ListModel, registerModel RegisterModel) *Server { return &Server{ address: cfg.BindAddress(), handlerGet: NewGetHandler(cfg, getModel), handlerList: NewListHandler(listModel), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(registerModel), } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, Addr: s.address, } // listen 
for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch09/acme/main.go ================================================ package main import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch09/acme/internal/rest" ) func main() { // bind stop channel to context ctx := context.Background() // build the exchanger exchanger := exchange.NewConverter(config.App) // build model layer getModel := get.NewGetter(config.App) listModel := list.NewLister(config.App) registerModel := register.NewRegisterer(config.App, exchanger) // start REST server server := rest.New(config.App, getModel, listModel, registerModel) server.Listen(ctx.Done()) } ================================================ FILE: ch09/fake.go ================================================ package ch09 func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } 
================================================
FILE: ch10/01_intro_to_wire/01_simple/main.go
================================================
package main

import (
	"errors"
	"fmt"

	"github.com/google/wire"
)

// main obtains a *Fetcher from the initializeDeps injector and calls it once.
func main() {
	f := initializeDeps()

	result, err := f.GoFetch()
	// NOTE(review): %s on a nil error would print "%!s(<nil>)" (%v prints
	// "<nil>"); GoFetch currently always returns a non-nil error.
	fmt.Printf("Result: %s / %s", result, err)
}

// list of wire enabled dependencies
var wireSet = wire.NewSet(ProvideFetcher)

// Provider
// ProvideFetcher returns a new, empty *Fetcher.
func ProvideFetcher() *Fetcher {
	return &Fetcher{}
}

// Object being "provided"
type Fetcher struct {
}

// GoFetch is a placeholder: it always returns "" and a "not implemented yet"
// error.
func (f *Fetcher) GoFetch() (string, error) {
	return "", errors.New("not implemented yet")
}

================================================
FILE: ch10/01_intro_to_wire/01_simple/wire.go
================================================
//+build wireinject

package main

import (
	"github.com/google/wire"
)

// The build tag makes sure the stub is not built in the final build.
// initializeDeps is the wire injector stub: wire.Build lists the providers to
// use, and the generated wire_gen.go (built with !wireinject) supplies the
// real body. The `return nil` only exists to satisfy the compiler.
func initializeDeps() *Fetcher {
	wire.Build(wireSet)
	return nil
}

================================================
FILE: ch10/01_intro_to_wire/01_simple/wire_gen.go
================================================
// Code generated by Wire. DO NOT EDIT.
//go:generate wire
//+build !wireinject

package main

// Injectors from wire.go:

func initializeDeps() *Fetcher {
	fetcher := ProvideFetcher()
	return fetcher
}

================================================
FILE: ch10/01_intro_to_wire/02_params/main.go
================================================
package main

import (
	"errors"
	"fmt"

	"github.com/google/wire"
)

// main obtains a *Fetcher (with its *Cache) from the initializeDeps injector
// and calls it once.
func main() {
	f := initializeDeps()

	result, err := f.GoFetch()
	fmt.Printf("Result: %s / %s", result, err)
}

// list of wire enabled dependencies
// wire satisfies ProvideFetcher's *Cache parameter by calling ProvideCache.
var wireSet = wire.NewSet(ProvideFetcher, ProvideCache)

// Providers
func ProvideFetcher(cache *Cache) *Fetcher {
	return &Fetcher{
		cache: cache,
	}
}

func ProvideCache() *Cache {
	return &Cache{}
}

// Cache is a placeholder dependency; its methods always return an error.
type Cache struct{}

func (c *Cache) Get(key string) (string, error) {
	return "", errors.New("not implemented yet")
}

func (c *Cache) Set(key string, value string) error {
	return errors.New("not implemented")
}

// Fetcher holds its injected *Cache; GoFetch is not implemented yet.
type Fetcher struct {
	cache *Cache
}

func (f *Fetcher) GoFetch() (string, error) {
	return "", errors.New("not implemented yet")
}

================================================
FILE: ch10/01_intro_to_wire/02_params/wire.go
================================================
//+build wireinject

package main

import (
	"github.com/google/wire"
)

// The build tag makes sure the stub is not built in the final build.
func initializeDeps() *Fetcher {
	wire.Build(wireSet)
	return nil
}

================================================
FILE: ch10/01_intro_to_wire/02_params/wire_gen.go
================================================
// Code generated by Wire. DO NOT EDIT.
//go:generate wire
//+build !wireinject

package main

// Injectors from wire.go:

func initializeDeps() *Fetcher {
	cache := ProvideCache()
	fetcher := ProvideFetcher(cache)
	return fetcher
}

================================================
FILE: ch10/01_intro_to_wire/03_error/main.go
================================================
package main

import (
	"errors"
	"fmt"

	"github.com/google/wire"
)

// main builds the graph via initializeDeps and panics if a provider fails.
// NOTE: Cache.Start below always returns an error, so as written ProvideCache
// fails, initializeDeps returns that error and main panics before GoFetch.
func main() {
	f, err := initializeDeps()
	if err != nil {
		panic(err.Error())
	}

	result, err := f.GoFetch()
	fmt.Printf("Result: %s / %s", result, err)
}

// list of wire enabled dependencies
var wireSet = wire.NewSet(ProvideFetcher, ProvideCache)

// Providers
func ProvideFetcher(cache *Cache) *Fetcher {
	return &Fetcher{
		cache: cache,
	}
}

// ProvideCache returns a started *Cache or the error from Start. Because this
// provider can fail, the initializeDeps injector in wire.go returns
// (*Fetcher, error).
func ProvideCache() (*Cache, error) {
	cache := &Cache{}
	err := cache.Start()
	if err != nil {
		return nil, err
	}
	return cache, nil
}

// Cache is a placeholder dependency; every method currently returns an error.
type Cache struct{}

func (c *Cache) Start() error {
	return errors.New("not implemented yet")
}

func (c *Cache) Get(key string) (string, error) {
	return "", errors.New("not implemented yet")
}

func (c *Cache) Set(key string, value string) error {
	return errors.New("not implemented")
}

type Fetcher struct {
	cache *Cache
}

func (f *Fetcher) GoFetch() (string, error) {
	return "", errors.New("not implemented yet")
}

================================================
FILE: ch10/01_intro_to_wire/03_error/wire.go
================================================
//+build wireinject

package main

import (
	"github.com/google/wire"
)

// The build tag makes sure the stub is not built in the final build.
func initializeDeps() (*Fetcher, error) {
	wire.Build(wireSet)
	return nil, nil
}

================================================
FILE: ch10/01_intro_to_wire/03_error/wire_gen.go
================================================
// Code generated by Wire. DO NOT EDIT.
//go:generate wire
//+build !wireinject

package main

// Injectors from wire.go:

func initializeDeps() (*Fetcher, error) {
	cache, err := ProvideCache()
	if err != nil {
		return nil, err
	}
	fetcher := ProvideFetcher(cache)
	return fetcher, nil
}

================================================
FILE: ch10/01_intro_to_wire/04_without_pset/main.go
================================================
//+build ignore
// Code above this line should be ignored as it's not part of the example

package main

import (
	"context"
	"os"
)

// main builds the whole server through the initializeServer wire injector,
// calling os.Exit(-1) if construction fails.
func main() {
	// bind stop channel to context
	ctx := context.Background()

	// start REST server
	server, err := initializeServer()
	if err != nil {
		os.Exit(-1)
	}

	server.Listen(ctx.Done())
}

================================================
FILE: ch10/01_intro_to_wire/04_without_pset/wire.go
================================================
//+build ignore
// Code above this line should be ignored as it's not part of the example

//+build wireinject

package main

import (
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest"

	"github.com/google/wire"
)

// The build tag makes sure the stub is not built in the final build.
func initializeServer() (*rest.Server, error) { wire.Build( // *config.Config config.Load, // *exchange.Converter wire.Bind(new(exchange.Config), &config.Config{}), exchange.NewConverter, // *get.Getter wire.Bind(new(get.Config), &config.Config{}), get.NewGetter, // *list.Lister wire.Bind(new(list.Config), &config.Config{}), list.NewLister, // *register.Registerer wire.Bind(new(register.Config), &config.Config{}), wire.Bind(new(register.Exchanger), &exchange.Converter{}), register.NewRegisterer, // *rest.Server wire.Bind(new(rest.Config), &config.Config{}), wire.Bind(new(rest.GetModel), &get.Getter{}), wire.Bind(new(rest.ListModel), &list.Lister{}), wire.Bind(new(rest.RegisterModel), ®ister.Registerer{}), rest.New, ) return nil, nil } ================================================ FILE: ch10/01_intro_to_wire/04_without_pset/wire_gen.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example // Code generated by Wire. DO NOT EDIT. 
//go:generate wire //+build !wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" ) // Injectors from wire.go: func initializeServer() (*rest.Server, error) { configConfig, err := config.Load() if err != nil { return nil, err } getter := get.NewGetter(configConfig) lister := list.NewLister(configConfig) converter := exchange.NewConverter(configConfig) registerer := register.NewRegisterer(configConfig, converter) server := rest.New(configConfig, getter, lister, registerer) return server, nil } ================================================ FILE: ch10/02_advantages/01_dig/main.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example package main import ( "context" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" "go.uber.org/dig" ) func main() { // bind stop channel to context ctx := context.Background() // build DIG 
container container := BuildContainer() // start REST server err := container.Invoke(func(server *rest.Server) { server.Listen(ctx.Done()) }) if err != nil { os.Exit(-1) } } func BuildContainer() *dig.Container { container := dig.New() container.Provide(config.Load) container.Provide(exchange.NewConverter) container.Provide(get.NewGetter) container.Provide(list.NewLister) container.Provide(register.NewRegisterer) container.Provide(rest.New) return container } ================================================ FILE: ch10/02_advantages/02_instantiation_order/handler.go ================================================ package main import ( "net/http" ) func NewGetPersonHandler(model *GetPersonModel) *GetPersonHandler { return &GetPersonHandler{ model: model, } } type GetPersonHandler struct { model *GetPersonModel } func (g *GetPersonHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { response.WriteHeader(http.StatusInternalServerError) response.Write([]byte(`not implemented yet`)) } ================================================ FILE: ch10/02_advantages/02_instantiation_order/injectors.go ================================================ //+build wireinject package main import ( "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. func initializeDeps() *GetPersonHandler { wire.Build(wireSet) return nil } ================================================ FILE: ch10/02_advantages/02_instantiation_order/main.go ================================================ package main func main() { // something fantastic goes here! 
} ================================================ FILE: ch10/02_advantages/02_instantiation_order/model.go ================================================ package main import ( "database/sql" "errors" ) func NewGetPersonModel(db *sql.DB) *GetPersonModel { return &GetPersonModel{ db: db, } } type GetPersonModel struct { db *sql.DB } func (g *GetPersonModel) LoadByID(ID int) (*Person, error) { return nil, errors.New("not implemented yet") } type Person struct { Name string } ================================================ FILE: ch10/02_advantages/02_instantiation_order/providers.go ================================================ package main import ( "database/sql" "github.com/google/wire" ) func ProvideHandler(model *GetPersonModel) *GetPersonHandler { return &GetPersonHandler{ model: model, } } func ProvideModel(db *sql.DB) *GetPersonModel { return &GetPersonModel{ db: db, } } func ProvideDatabase() *sql.DB { return &sql.DB{} } var wireSet = wire.NewSet( ProvideHandler, ProvideModel, ProvideDatabase, ) ================================================ FILE: ch10/02_advantages/02_instantiation_order/wire_gen.go ================================================ // Code generated by Wire. DO NOT EDIT. 
//go:generate wire
//+build !wireinject

package main

// Injectors from injectors.go:

func initializeDeps() *GetPersonHandler {
	db := ProvideDatabase()
	getPersonModel := ProvideModel(db)
	getPersonHandler := ProvideHandler(getPersonModel)
	return getPersonHandler
}

================================================
FILE: ch10/03_applying/01_before_config/main.go
================================================
//+build ignore
// Code above this line should be ignored as it's not part of the example

package main

import (
	"context"
	"os"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest"
)

// main wires the service entirely by hand: config, exchanger, model layer,
// then the REST server. It calls os.Exit(-1) if the config cannot be loaded.
func main() {
	// bind stop channel to context
	ctx := context.Background()

	// load config
	cfg, err := config.Load(config.DefaultEnvVar)
	if err != nil {
		os.Exit(-1)
	}

	// build the exchanger
	exchanger := exchange.NewConverter(cfg)

	// build model layer
	getModel := get.NewGetter(cfg)
	listModel := list.NewLister(cfg)
	registerModel := register.NewRegisterer(cfg, exchanger)

	// start REST server
	server := rest.New(cfg, getModel, listModel, registerModel)
	server.Listen(ctx.Done())
}

================================================
FILE: ch10/03_applying/02_after_config/main.go
================================================
//+build ignore
// Code above this line should be ignored as it's not part of the example

package main

import (
	"context"
	"os"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange"
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" ) func main() { // bind stop channel to context ctx := context.Background() // load config cfg, err := initializeConfig() if err != nil { os.Exit(-1) } // build the exchanger exchanger := exchange.NewConverter(cfg) // build model layer getModel := get.NewGetter(cfg) listModel := list.NewLister(cfg) registerModel := register.NewRegisterer(cfg, exchanger) // start REST server server := rest.New(cfg, getModel, listModel, registerModel) server.Listen(ctx.Done()) } ================================================ FILE: ch10/03_applying/02_after_config/wire.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example //+build wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. func initializeConfig() (*config.Config, error) { wire.Build(config.Load) return nil, nil } ================================================ FILE: ch10/03_applying/02_after_config/wire_gen.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example // Code generated by Wire. DO NOT EDIT. 
//go:generate wire
//+build !wireinject

package main

import (
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config"
)

// Injectors from wire.go:

func initializeConfig() (*config.Config, error) {
	configConfig, err := config.Load()
	if err != nil {
		return nil, err
	}
	return configConfig, nil
}

================================================
FILE: ch10/03_applying/03_after_exchange/main.go
================================================
//+build ignore
// Code above this line should be ignored as it's not part of the example

package main

import (
	"context"
	"os"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest"
	"github.com/google/wire"
)

// main obtains the config and exchanger from wire injectors and builds the
// rest by hand; it calls os.Exit(-1) if either injector fails.
func main() {
	// bind stop channel to context
	ctx := context.Background()

	// load config
	cfg, err := initializeConfig()
	if err != nil {
		os.Exit(-1)
	}

	// build the exchanger
	// NOTE(review): wireSet includes config.Load, so initializeExchanger
	// presumably loads its own *config.Config instead of reusing cfg — confirm
	// against the generated wire_gen.go.
	exchanger, err := initializeExchanger()
	if err != nil {
		os.Exit(-1)
	}

	// build model layer
	getModel := get.NewGetter(cfg)
	listModel := list.NewLister(cfg)
	registerModel := register.NewRegisterer(cfg, exchanger)

	// start REST server
	server := rest.New(cfg, getModel, listModel, registerModel)
	server.Listen(ctx.Done())
}

// List of wire enabled objects
var wireSet = wire.NewSet(
	// *config.Config
	config.Load,

	// *exchange.Converter
	wire.Bind(new(exchange.Config), &config.Config{}),
	exchange.NewConverter,
)

================================================
FILE: ch10/03_applying/03_after_exchange/wire.go
================================================ //+build ignore // Code above this line should be ignored as it's not part of the example //+build wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. func initializeConfig() (*config.Config, error) { wire.Build(wireSet) return nil, nil } func initializeExchanger() (*exchange.Converter, error) { wire.Build(wireSet) return nil, nil } ================================================ FILE: ch10/03_applying/04_after_model/main.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example package main import ( "context" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" "github.com/google/wire" ) func main() { // bind stop channel to context ctx := context.Background() // load config cfg, err := initializeConfig() if err != nil { os.Exit(-1) } // build model layer getModel, _ := initializeGetter() listModel, _ := initializeLister() registerModel, _ := initializeRegisterer() // start REST server server := rest.New(cfg, getModel, listModel, registerModel) server.Listen(ctx.Done()) } // List of wire enabled objects var wireSet = wire.NewSet( // *config.Config 
config.Load, // *exchange.Converter wire.Bind(new(exchange.Config), &config.Config{}), exchange.NewConverter, // *get.Getter wire.Bind(new(get.Config), &config.Config{}), get.NewGetter, // *list.Lister wire.Bind(new(list.Config), &config.Config{}), list.NewLister, // *register.Registerer wire.Bind(new(register.Config), &config.Config{}), wire.Bind(new(register.Exchanger), &exchange.Converter{}), register.NewRegisterer, ) ================================================ FILE: ch10/03_applying/04_after_model/wire.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example //+build wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. 
func initializeConfig() (*config.Config, error) { wire.Build(wireSet) return nil, nil } func initializeGetter() (*get.Getter, error) { wire.Build(wireSet) return nil, nil } func initializeLister() (*list.Lister, error) { wire.Build(wireSet) return nil, nil } func initializeRegisterer() (*register.Registerer, error) { wire.Build(wireSet) return nil, nil } ================================================ FILE: ch10/03_applying/05_after_rest/main.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example package main import ( "context" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" "github.com/google/wire" ) func main() { // bind stop channel to context ctx := context.Background() // start REST server server, err := initializeServer() if err != nil { os.Exit(-1) } server.Listen(ctx.Done()) } // List of wire enabled objects var wireSet = wire.NewSet( // *config.Config config.Load, // *exchange.Converter wire.Bind(new(exchange.Config), &config.Config{}), exchange.NewConverter, // *get.Getter wire.Bind(new(get.Config), &config.Config{}), get.NewGetter, // *list.Lister wire.Bind(new(list.Config), &config.Config{}), list.NewLister, // *register.Registerer wire.Bind(new(register.Config), &config.Config{}), wire.Bind(new(register.Exchanger), &exchange.Converter{}), register.NewRegisterer, // *rest.Server wire.Bind(new(rest.Config), &config.Config{}), 
wire.Bind(new(rest.GetModel), &get.Getter{}), wire.Bind(new(rest.ListModel), &list.Lister{}), wire.Bind(new(rest.RegisterModel), ®ister.Registerer{}), rest.New, ) ================================================ FILE: ch10/03_applying/05_after_rest/wire.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example //+build wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. func initializeServer() (*rest.Server, error) { wire.Build(wireSet) return nil, nil } ================================================ FILE: ch10/03_applying/05_after_rest/wire_gen.go ================================================ //+build ignore // Code above this line should be ignored as it's not part of the example // Code generated by Wire. DO NOT EDIT. //go:generate wire //+build !wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" ) // Injectors from wire.go: func initializeServer() (*rest.Server, error) { configConfig, err := config.Load() if err != nil { return nil, err } getter := get.NewGetter(configConfig) lister := list.NewLister(configConfig) converter := exchange.NewConverter(configConfig) registerer := register.NewRegisterer(configConfig, converter) server := rest.New(configConfig, getter, 
lister, registerer) return server, nil } ================================================ FILE: ch10/03_applying/06_build_tag.go ================================================ //+build myTag package main import ( "fmt" ) func sayHello() { fmt.Println("Hello World!") } ================================================ FILE: ch10/03_applying/06_build_tag_inverse.go ================================================ //+build !myTag package main import ( "fmt" ) func sayHello() { fmt.Println("Hello Universe!") } ================================================ FILE: ch10/03_applying/06_main.go ================================================ package main func main() { sayHello() } ================================================ FILE: ch10/04_disadvantages/01_complexity/main.go ================================================ package main import ( "encoding/json" "io/ioutil" "go.uber.org/dig" ) const ( configFile = "config.json" ) func main() { c := dig.New() err := c.Provide(func() (*Config, error) { out := &Config{} bytes, err := ioutil.ReadFile(configFile) if err != nil { return nil, err } err = json.Unmarshal(bytes, out) if err != nil { return nil, err } return out, nil }) if err != nil { panic(err) } err = c.Provide(func(cfg *Config) *Logger { return &Logger{level: cfg.Level} }) if err != nil { panic(err) } err = c.Provide(func(logger *Logger) *Server { return &Server{logger: logger} }) if err != nil { panic(err) } err = c.Invoke(func(server *Server) { server.Listen() }) if err != nil { panic(err) } } type Config struct { Level string } type Logger struct { level string } func (l *Logger) Debug(msg string, args ...interface{}) { // not implemented } func (l *Logger) Warn(msg string, args ...interface{}) { // not implemented } func (l *Logger) Error(msg string, args ...interface{}) { // not implemented } type Server struct { logger *Logger } func (s *Server) Listen() { // not implemented } ================================================ FILE: 
ch10/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "fmt" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string // environmental dependencies logger logging.Logger } // Logger returns a reference to the singleton logger func (c *Config) Logger() logging.Logger { if c.logger == nil { c.logger = &logging.LoggerStdOut{} } return c.logger } // RegistrationBasePrice returns the base price for registrations func (c *Config) RegistrationBasePrice() float64 { return c.BasePrice } // DataDSN returns the DSN func (c *Config) DataDSN() string { return c.DSN } // ExchangeBaseURL returns the Base URL from which we can load exchange rates func (c *Config) ExchangeBaseURL() string { return c.ExchangeRateBaseURL } // ExchangeAPIKey returns the DSN func (c *Config) ExchangeAPIKey() string { return c.ExchangeRateAPIKey } // BindAddress returns the host and port this service should bind to func (c *Config) BindAddress() string { return c.Address } // Load returns the config loaded from environment func Load() (*Config, error) { filename, found := os.LookupEnv(DefaultEnvVar) if !found { err := fmt.Errorf("failed to locate file specified by %s", DefaultEnvVar) logging.L.Error(err.Error()) return nil, err } cfg, err 
:= load(filename) if err != nil { logging.L.Error("failed to load config with err %s", err) return nil, err } return cfg, nil } func load(filename string) (*Config, error) { out := &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { logging.L.Error("failed to read config file. err: %s", err) return nil, err } err = json.Unmarshal(bytes, out) if err != nil { logging.L.Error("failed to parse config file. err : %s", err) return nil, err } return out, nil } ================================================ FILE: ch10/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: nil, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { result, resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, result, scenario.desc) }) } } ================================================ FILE: ch10/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // Logger is our standard interface type Logger interface { Debug(message string, args ...interface{}) Info(message string, args ...interface{}) Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } // L is the global instance of the logger var L = &LoggerStdOut{} // LoggerStdOut logs to std 
out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) } // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) } ================================================ FILE: ch10/acme/internal/modules/data/dao.go ================================================ package data import ( "context" "database/sql" "time" ) // NewDAO will initialize the database connection pool (if not already done) and return a data access object which // can be used to interact with the database func NewDAO(cfg Config) *DAO { // initialize the db connection pool _, _ = getDB(cfg) return &DAO{ cfg: cfg, } } // DAO is a data access object that provides an abstraction over our database interactions. type DAO struct { cfg Config // Tracker is an optional query timer Tracker QueryTracker } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) { // track processing time defer d.getTracker().Track("Load", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. 
err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select row := db.QueryRowContext(subCtx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { d.cfg.Logger().Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } d.cfg.Logger().Error("failed to convert query result. err: %s", err) return nil, err } return out, nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) LoadAll(ctx context.Context) ([]*Person, error) { // track processing time defer d.getTracker().Track("LoadAll", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select rows, err := db.QueryContext(subCtx, sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { d.cfg.Logger().Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { d.cfg.Logger().Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. 
func (d *DAO) Save(ctx context.Context, in *Person) (int, error) { // track processing time defer d.getTracker().Track("Save", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB insert result, err := db.ExecContext(subCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { d.cfg.Logger().Error("failed to save person into DB. err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { d.cfg.Logger().Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } func (d *DAO) getTracker() QueryTracker { if d.Tracker == nil { d.Tracker = &noopTracker{} } return d.Tracker } ================================================ FILE: ch10/acme/internal/modules/data/data.go ================================================ package data import ( "database/sql" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) // Config is the configuration for the data package type Config interface { // Logger returns a reference to the logger Logger() logging.Logger // DataDSN returns the data source name DataDSN() string } var getDB = func(cfg Config) (*sql.DB, error) { if db == nil { var err error db, err = sql.Open("mysql", cfg.DataDSN()) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch10/acme/internal/modules/data/data_test.go ================================================ package data import ( "context" "database/sql" "errors" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() 
require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patching starts here defer func(original func(_ Config) (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func(_ Config) (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, 
in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) results, err := dao.LoadAll(ctx) // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := 
convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) result, err := dao.Load(ctx, 2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } type testConfig struct{} // Logger implements Config func (t *testConfig) Logger() logging.Logger { return logging.LoggerStdOut{} } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } ================================================ FILE: ch10/acme/internal/modules/data/tracker.go ================================================ package data import ( "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" ) // QueryTracker is an interface to track query timing type QueryTracker interface { // Track will record/out the time a query took by calculating time.Now().Sub(start) Track(key string, start time.Time) } // NO-OP implementation of QueryTracker type 
noopTracker struct{} // Track implements QueryTracker func (_ *noopTracker) Track(_ string, _ time.Time) { // intentionally does nothing } // NewLogTracker returns a Tracker that outputs tracking data to log func NewLogTracker(logger logging.Logger) *LogTracker { return &LogTracker{ logger: logger, } } // LogTracker implements QueryTracker and outputs to the supplied logger type LogTracker struct { logger logging.Logger } // Track implements QueryTracker func (l *LogTracker) Track(key string, start time.Time) { l.logger.Info("[%s] Timing: %s\n", key, time.Now().Sub(start).String()) } ================================================ FILE: ch10/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // NewConverter creates and initializes the converter func NewConverter(cfg Config) *Converter { return &Converter{ cfg: cfg, } } // Config is the config for Converter type Config interface { Logger() logging.Logger ExchangeBaseURL() string ExchangeAPIKey() string } // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct { cfg Config } // Exchange will perform the conversion func (c *Converter) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { // load rate from the external API response, err := c.loadRateFromServer(ctx, currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 
decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, c.cfg.ExchangeBaseURL(), c.cfg.ExchangeAPIKey(), currency) // perform request req, err := http.NewRequest("GET", url, nil) if err != nil { c.logger().Warn("[exchange] failed to create request. err: %s", err) return nil, err } // set latency budget for the upstream call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // replace the default context with our custom one req = req.WithContext(subCtx) // perform the HTTP request response, err := http.DefaultClient.Do(req) if err != nil { c.logger().Warn("[exchange] failed to load. err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) c.logger().Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) c.logger().Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { c.logger().Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { c.logger().Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } func (c *Converter) logger() logging.Logger { return c.cfg.Logger() } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch10/acme/internal/modules/exchange/converter_ext_bounday_test.go ================================================ // +build external package exchange import ( "context" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestExternalBoundaryTest(t *testing.T) { // define the config cfg, err := config.Load() require.NoError(t, err) // create a converter to test converter := NewConverter(cfg) // fetch from the server response, err := converter.loadRateFromServer(context.Background(), "AUD") require.NotNil(t, response) require.NoError(t, err) // parse the response resultRate, err := converter.extractRate(response, "AUD") require.NoError(t, err) // validate the result assert.True(t, resultRate > 0) } ================================================ FILE: ch10/acme/internal/modules/exchange/converter_int_bounday_test.go ================================================ package exchange import ( "context" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/stretchr/testify/assert" ) func TestInternalBoundaryTest(t *testing.T) { // start our test server server := httptest.NewServer(&happyExchangeRateService{}) defer server.Close() // define the config cfg := &testConfig{ baseURL: server.URL, apiKey: "", } // create a converter to test converter := NewConverter(cfg) resultRate, resultErr := converter.Exchange(context.Background(), 100.00, "AUD") // validate the result assert.Equal(t, 101.01, resultRate) assert.NoError(t, resultErr) } type 
happyExchangeRateService struct{} // ServeHTTP implements http.Handler func (*happyExchangeRateService) ServeHTTP(response http.ResponseWriter, request *http.Request) { payload := []byte(` { "success":true, "historical":true, "date":"2010-11-09", "timestamp":1289347199, "source":"USD", "quotes":{ "USDAUD":0.989981 } }`) response.Write(payload) } // test implementation of Config type testConfig struct { baseURL string apiKey string } // Logger implements Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // ExchangeBaseURL implements Config func (t *testConfig) ExchangeBaseURL() string { return t.baseURL } // ExchangeAPIKey implements Config func (t *testConfig) ExchangeAPIKey() string { return t.apiKey } ================================================ FILE: ch10/acme/internal/modules/get/get.go ================================================ package get import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // NewGetter creates and initializes a Getter func NewGetter(cfg Config) *Getter { return &Getter{ cfg: cfg, } } // Config is the configuration for Getter type Config interface { Logger() logging.Logger DataDSN() string } // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { cfg Config data myLoader } // Do will perform the get func (g *Getter) Do(ID int) (*data.Person, error) { // load person from the data layer person, err := g.getLoader().Load(context.TODO(), ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. 
return nil, errPersonNotFound } return nil, err } return person, err } func (g *Getter) getLoader() myLoader { if g.data == nil { g.data = data.NewDAO(g.cfg) } return g.data } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { Load(ctx context.Context, ID int) (*data.Person, error) } ================================================ FILE: ch10/acme/internal/modules/get/go_test.go ================================================ package get import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockResult := &data.Person{ ID: 1234, FullName: "Doug", } mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(mockResult, nil).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, data.ErrNotFound).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, errors.New("something failed")).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate 
expectations require.Error(t, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: ch10/acme/internal/modules/get/mock_my_loader_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package get import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMyLoader is an autogenerated mock type for the myLoader type type mockMyLoader struct { mock.Mock } // Load provides a mock function with given fields: ctx, ID func (_m *mockMyLoader) Load(ctx context.Context, ID int) (*data.Person, error) { ret := _m.Called(ctx, ID) var r0 *data.Person if rf, ok := ret.Get(0).(func(context.Context, int) *data.Person); ok { r0 = rf(ctx, ID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { r1 = rf(ctx, ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch10/acme/internal/modules/list/list.go ================================================ package list import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // NewLister creates and initializes a Lister func NewLister(cfg Config) *Lister { return &Lister{ cfg: cfg, } } // Config is the config for Lister type Config interface { Logger() logging.Logger DataDSN() string } // Lister will attempt to load all people in the database. 
// It can return an error caused by the data layer type Lister struct { cfg Config data myLoader } // Exchange will load the people from the data layer func (l *Lister) Do() ([]*data.Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return people, nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := l.getLoader().LoadAll(context.TODO()) if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. return nil, errPeopleNotFound } return nil, err } return people, nil } func (l *Lister) getLoader() myLoader { if l.data == nil { l.data = data.NewDAO(l.cfg) // temporarily add a log tracker l.data.(*data.DAO).Tracker = data.NewLogTracker(l.cfg.Logger()) } return l.data } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { LoadAll(ctx context.Context) ([]*data.Person, error) } ================================================ FILE: ch10/acme/internal/modules/list/list_test.go ================================================ package list import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // configure the mock loader mockResult := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(mockResult, nil).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func 
TestLister_Do_noResults(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, data.ErrNotFound).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func TestLister_Do_error(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, errors.New("something failed")).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: ch10/acme/internal/modules/list/mock_my_loader_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package list import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMyLoader is an autogenerated mock type for the myLoader type type mockMyLoader struct { mock.Mock } // LoadAll provides a mock function with given fields: ctx func (_m *mockMyLoader) LoadAll(ctx context.Context) ([]*data.Person, error) { ret := _m.Called(ctx) var r0 []*data.Person if rf, ok := ret.Get(0).(func(context.Context) []*data.Person); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context) error); ok { r1 = rf(ctx) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch10/acme/internal/modules/register/mock_my_saver_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package register import ( 
"context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMySaver is an autogenerated mock type for the mySaver type type mockMySaver struct { mock.Mock } // Save provides a mock function with given fields: ctx, in func (_m *mockMySaver) Save(ctx context.Context, in *data.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok := ret.Get(0).(func(context.Context, *data.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch10/acme/internal/modules/register/register.go ================================================ package register import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // NewRegisterer creates and initializes a Registerer func NewRegisterer(cfg Config, exchanger Exchanger) *Registerer { return &Registerer{ cfg: cfg, exchanger: exchanger, } } // Exchanger will convert from one currency to another type Exchanger interface { // Exchange will perform the conversion Exchange(ctx context.Context, basePrice float64, 
currency string) (float64, error) } // Config is the configuration for the Registerer type Config interface { Logger() logging.Logger RegistrationBasePrice() float64 DataDSN() string } // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. type Registerer struct { cfg Config exchanger Exchanger data mySaver } // Do is API for this struct func (r *Registerer) Do(ctx context.Context, in *data.Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { r.logger().Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(ctx, in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(ctx, in, price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *data.Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) { price, err := r.exchanger.Exchange(ctx, r.cfg.RegistrationBasePrice(), currency) if err != nil { r.logger().Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(ctx context.Context, in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return r.getSaver().Save(ctx, person) } func (r *Registerer) getSaver() mySaver { if r.data == nil { r.data = data.NewDAO(r.cfg) } return r.data } func (r *Registerer) logger() logging.Logger { return r.cfg.Logger() } //go:generate mockery -name=mySaver -case underscore -testonly -inpkg -note @generated type mySaver interface { Save(ctx context.Context, in *data.Person) (int, error) } ================================================ FILE: ch10/acme/internal/modules/register/register_test.go ================================================ package register import ( "context" "errors" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // configure the mock saver mockResult := 888 mockSaver := &mockMySaver{} mockSaver.On("Save", mock.Anything, mock.Anything).Return(mockResult, nil).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) assert.True(t, mockSaver.AssertExpectations(t)) } func TestRegisterer_Do_error(t *testing.T) { // configure the mock saver mockSaver := &mockMySaver{} 
mockSaver.On("Save", mock.Anything, mock.Anything).Return(0, errors.New("something failed")).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &data.Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) assert.True(t, mockSaver.AssertExpectations(t)) } // Stub implementation of Config type testConfig struct{} // Logger implement Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // RegistrationBasePrice implement Config func (t *testConfig) RegistrationBasePrice() float64 { return 12.34 } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } type stubExchanger struct{} // Exchange implements Exchanger func (s stubExchanger) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { return 12.34, nil } ================================================ FILE: ch10/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration //go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated type GetModel interface { Do(ID int) (*data.Person, error) } // GetConfig is the config for the Get Handler type GetConfig interface { Logger() logging.Logger } // NewGetHandler is 
the constructor for GetHandler func NewGetHandler(cfg GetConfig, model GetModel) *GetHandler { return &GetHandler{ cfg: cfg, getter: model, } } // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. type GetHandler struct { cfg GetConfig getter GetModel } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") h.cfg.Logger().Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) h.cfg.Logger().Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *data.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch10/acme/internal/rest/get_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockGetModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { output := &data.Person{ ID: 1, FullName: "John", Phone: "0123456789", Currency: "USD", Price: 100, } mockGetModel := &MockGetModel{} mockGetModel.On("Do", mock.Anything).Return(output, nil).Once() return mockGetModel }, expectedStatus: http.StatusOK, expectedPayload: 
`{"id":1,"name":"John","phone":"0123456789","currency":"USD","price":100}` + "\n", }, { desc: "bad input (ID is invalid)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/x/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "x"}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "bad input (ID is missing)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person//", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "dependency fail", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "requested registration does not exist", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("person not found")).Once() return mockRegisterModel }, expectedStatus: 
http.StatusNotFound, expectedPayload: ``, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockGetModel := scenario.inModelMock() // build handler handler := NewGetHandler(&testConfig{}, mockGetModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } type testConfig struct { } func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } func (*testConfig) BindAddress() string { return "0.0.0.0:0" } ================================================ FILE: ch10/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" ) // ListModel will load all registrations //go:generate mockery -name=ListModel -case underscore -testonly -inpkg -note @generated type ListModel interface { Do() ([]*data.Person, error) } // NewLister is the constructor for ListHandler func NewListHandler(model ListModel) *ListHandler { return &ListHandler{ lister: model, } } // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { lister ListModel } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll people, err := h.lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it 
does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*data.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch10/acme/internal/rest/list_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockListModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { output := []*data.Person{ { ID: 1, FullName: "John", Phone: "0123456789", }, { ID: 2, FullName: "Paul", Phone: "0123456781", }, { ID: 3, FullName: "George", Phone: "0123456782", }, { ID: 1, FullName: "Ringo", Phone: "0123456783", }, } mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, 
expectedPayload: `{"people":[{"id":1,"name":"John","phone":"0123456789"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":1,"name":"Ringo","phone":"0123456783"}]}` + "\n", }, { desc: "dependency failure", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockListModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "no data", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { // no data output := []*data.Person{} mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[]}` + "\n", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockListModel := scenario.inModelMock() // build handler handler := NewListHandler(mockListModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch10/acme/internal/rest/mock_get_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockGetModel is an autogenerated mock type for the GetModel 
type type MockGetModel struct { mock.Mock } // Do provides a mock function with given fields: ID func (_m *MockGetModel) Do(ID int) (*data.Person, error) { ret := _m.Called(ID) var r0 *data.Person if rf, ok := ret.Get(0).(func(int) *data.Person); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(int) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch10/acme/internal/rest/mock_list_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockListModel is an autogenerated mock type for the ListModel type type MockListModel struct { mock.Mock } // Do provides a mock function with given fields: func (_m *MockListModel) Do() ([]*data.Person, error) { ret := _m.Called() var r0 []*data.Person if rf, ok := ret.Get(0).(func() []*data.Person); ok { r0 = rf() } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func() error); ok { r1 = rf() } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch10/acme/internal/rest/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: ctx, in func (_m *MockRegisterModel) Do(ctx context.Context, in *data.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok := 
ret.Get(0).(func(context.Context, *data.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch10/acme/internal/rest/not_found.go ================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch10/acme/internal/rest/not_found_test.go ================================================ package rest import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" ) func TestNotFoundHandler_ServeHTTP(t *testing.T) { // build inputs response := httptest.NewRecorder() request := &http.Request{} // call handler notFoundHandler(response, request) // validate outputs require.Equal(t, http.StatusNotFound, response.Code) } ================================================ FILE: ch10/acme/internal/rest/register.go ================================================ package rest import ( "context" "encoding/json" "fmt" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/data" ) // RegisterModel will validate and save a registration //go:generate mockery -name=RegisterModel -case underscore -testonly -inpkg -note @generated type RegisterModel interface { Do(ctx context.Context, in *data.Person) (int, error) } // NewRegisterHandler is the constructor for RegisterHandler func NewRegisterHandler(model RegisterModel) *RegisterHandler { return &RegisterHandler{ registerer: model, } } // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 
400. // There are some programmer errors possible but hopefully these will be caught in testing. type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // set latency budget for this API subCtx, cancel := context.WithTimeout(request.Context(), 1500*time.Millisecond) defer cancel() // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // call the business logic using the request data and context id, err := h.register(subCtx, requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(ctx context.Context, requestPayload *registerRequest) (int, error) { person := &data.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } return h.registerer.Do(ctx, person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch10/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "encoding/json" "errors" "io" "net/http" 
"net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ { desc: "Happy Path", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // valid downstream configuration resultID := 1234 var resultErr error mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(resultID, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusCreated, expectedHeader: "/person/1234/", }, { desc: "Bad Input / User Error", inRequest: func() *http.Request { invalidRequest := bytes.NewBufferString(`this is not valid JSON`) request, err := http.NewRequest("POST", "/person/register", invalidRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // Dependency should not be called mockRegisterModel := &MockRegisterModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, { desc: "Dependency Failure", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // call to the dependency failed resultErr := errors.New("something failed") mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(0, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t 
*testing.T) { // define model layer mock mockRegisterModel := scenario.inModelMock() // build handler handler := NewRegisterHandler(mockRegisterModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRegisterRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch10/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/logging" "github.com/gorilla/mux" ) // Config is the config for the REST package type Config interface { Logger() logging.Logger BindAddress() string } // New will create and initialize the server func New(cfg Config, getModel GetModel, listModel ListModel, registerModel RegisterModel) *Server { return &Server{ address: cfg.BindAddress(), handlerGet: NewGetHandler(cfg, getModel), handlerList: NewListHandler(listModel), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(registerModel), } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, Addr: s.address, } // listen 
for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch10/acme/main.go ================================================ package main import ( "context" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" "github.com/google/wire" ) func main() { // bind stop channel to context ctx := context.Background() // start REST server server, err := initializeServer() if err != nil { os.Exit(-1) } server.Listen(ctx.Done()) } // List of wire enabled objects var wireSetWithoutConfig = wire.NewSet( // *exchange.Converter exchange.NewConverter, // *get.Getter get.NewGetter, // *list.Lister list.NewLister, // *register.Registerer wire.Bind(new(register.Exchanger), &exchange.Converter{}), register.NewRegisterer, // *rest.Server wire.Bind(new(rest.GetModel), &get.Getter{}), wire.Bind(new(rest.ListModel), &list.Lister{}), wire.Bind(new(rest.RegisterModel), ®ister.Registerer{}), rest.New, ) var 
wireSet = wire.NewSet( wireSetWithoutConfig, // *config.Config config.Load, // *exchange.Converter wire.Bind(new(exchange.Config), &config.Config{}), // *get.Getter wire.Bind(new(get.Config), &config.Config{}), // *list.Lister wire.Bind(new(list.Config), &config.Config{}), // *register.Registerer wire.Bind(new(register.Config), &config.Config{}), // *rest.Server wire.Bind(new(rest.Config), &config.Config{}), ) ================================================ FILE: ch10/acme/main_test.go ================================================ package main import ( "bytes" "context" "errors" "fmt" "net" "net/http" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegister(t *testing.T) { // start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request payload := bytes.NewBufferString(` { "fullName": "Bob", "phone": "0123456789", "currency": "AUD" } `) req, err := http.NewRequest("POST", serverAddress+"/person/register", payload) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusCreated, resp.StatusCode) assert.NotEmpty(t, resp.Header.Get("Location")) } func TestGet(t *testing.T) { // start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request req, err := http.NewRequest("GET", serverAddress+"/person/1/", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusOK, resp.StatusCode) } func TestList(t *testing.T) { // start a context with a max execution 
time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request req, err := http.NewRequest("GET", serverAddress+"/person/list", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusOK, resp.StatusCode) } func startTestServer(t *testing.T, ctx context.Context) string { // load the standard config (from the ENV) cfg, err := config.Load() require.NoError(t, err) // get a free port (so tests can run concurrently) port, err := getFreePort() require.NoError(t, err) // override config port with free one cfg.Address = net.JoinHostPort("0.0.0.0", port) // start the test server on a random port go func() { // start REST server server := initializeServerCustomConfig(cfg, cfg, cfg, cfg, cfg) server.Listen(ctx.Done()) }() // give the server a chance to start <-time.After(100 * time.Millisecond) // return the address of the test server return "http://" + cfg.Address } func getFreePort() (string, error) { for attempt := 0; attempt <= 10; attempt++ { addr := net.JoinHostPort("", "0") listener, err := net.Listen("tcp", addr) if err != nil { continue } port, err := getPort(listener.Addr()) if err != nil { continue } // close/free the port tcpListener := listener.(*net.TCPListener) cErr := tcpListener.Close() if cErr == nil { file, fErr := tcpListener.File() if fErr == nil { // ignore any errors cleaning up the file _ = file.Close() } return port, nil } } return "", errors.New("no free ports") } func getPort(addr fmt.Stringer) (string, error) { actualAddress := addr.String() _, port, err := net.SplitHostPort(actualAddress) return port, err } ================================================ FILE: ch10/acme/wire.go ================================================ //+build wireinject package main import ( 
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest" "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. func initializeServer() (*rest.Server, error) { wire.Build(wireSet) return nil, nil } func initializeServerCustomConfig(_ exchange.Config, _ get.Config, _ list.Config, _ register.Config, _ rest.Config) *rest.Server { wire.Build(wireSetWithoutConfig) return nil } ================================================ FILE: ch10/acme/wire_gen.go ================================================ // Code generated by Wire. DO NOT EDIT. 
//go:generate wire
//+build !wireinject

package main

import (
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/config"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/exchange"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/get"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/list"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/modules/register"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch10/acme/internal/rest"
)

// Injectors from wire.go:

func initializeServer() (*rest.Server, error) {
	configConfig, err := config.Load()
	if err != nil {
		return nil, err
	}
	getter := get.NewGetter(configConfig)
	lister := list.NewLister(configConfig)
	converter := exchange.NewConverter(configConfig)
	registerer := register.NewRegisterer(configConfig, converter)
	server := rest.New(configConfig, getter, lister, registerer)
	return server, nil
}

func initializeServerCustomConfig(exchangeConfig exchange.Config, getConfig get.Config, listConfig list.Config, registerConfig register.Config, restConfig rest.Config) *rest.Server {
	getter := get.NewGetter(getConfig)
	lister := list.NewLister(listConfig)
	converter := exchange.NewConverter(exchangeConfig)
	registerer := register.NewRegisterer(registerConfig, converter)
	server := rest.New(restConfig, getter, lister, registerer)
	return server
}

================================================
FILE: ch10/fake.go
================================================
// Package ch10 contains no code of its own; this placeholder file only exists
// so that Go tooling finds Go source in the directory.
package ch10

func init() {
	// This file is included so that Go tools (like `go list`) will find Go code in this directory and not error
}

================================================
FILE: ch11/01_di_induced_damage/01_long_param/01_long_param.go
================================================
package long_param

import (
	"net/http"
	"time"
)

// NewMyHandler builds a MyHandler, taking every dependency as a separate
// constructor parameter (the "long parameter list" this example illustrates).
func NewMyHandler(logger Logger,
stats Instrumentation, parser Parser, formatter Formatter, limiter RateLimiter, cache Cache, db Datastore) *MyHandler {
	return &MyHandler{
		// code removed
	}
}

// MyHandler does something fantastic
type MyHandler struct {
	// code removed
}

// ServeHTTP implements http.Handler; its body is elided in this example.
func (m *MyHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
	// code removed
}

// Logger logs stuff
type Logger interface {
	Error(message string, args ...interface{})
	Warn(message string, args ...interface{})
	Info(message string, args ...interface{})
	Debug(message string, args ...interface{})
}

// Instrumentation records the performances and events
type Instrumentation interface {
	Count(key string, value int)
	Duration(key string, start time.Time)
}

// Parser will extract details from the request
type Parser interface {
	Extract(req *http.Request) (int, error)
}

// Formatter will build the output
type Formatter interface {
	Format(resp http.ResponseWriter, data []byte) error
}

// RateLimiter limits how many concurrent requests we can make or process
type RateLimiter interface {
	Acquire()
	Release()
}

// Datastore will store/retrieve data in a permanent way
type Datastore interface {
	Load(ID int) ([]byte, error)
}

// Cache will store/retrieve data in a fast way
type Cache interface {
	Store(key string, data []byte)
	Get(key string) ([]byte, error)
}

================================================
FILE: ch11/01_di_induced_damage/02_long_param/01_long_param.go
================================================
package long_param

import (
	"net/http"
	"time"
)

// NewMyHandler builds a MyHandler. Compared with the 01_long_param version,
// the separate Cache and Datastore parameters are replaced by a single Loader.
func NewMyHandler(logger Logger, stats Instrumentation, parser Parser, formatter Formatter, limiter RateLimiter, loader Loader) *MyHandler {
	return &MyHandler{
		// code removed
	}
}

// MyHandler does something fantastic
type MyHandler struct {
	// code removed
}

// ServeHTTP implements http.Handler; its body is elided in this example.
func (m *MyHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
	// code removed
}

// Logger logs stuff
type Logger interface {
	Error(message string, args ...interface{})
	Warn(message string, args ...interface{})
	Info(message string, args ...interface{})
	Debug(message string, args ...interface{})
}

// Instrumentation records the performances and events
type Instrumentation interface {
	Count(key string, value int)
	Duration(key string, start time.Time)
}

// Parser will extract details from the request
type Parser interface {
	Extract(req *http.Request) (int, error)
}

// Formatter will build the output
type Formatter interface {
	Format(resp http.ResponseWriter, data []byte) error
}

// RateLimiter limits how many concurrent requests we can make or process
type RateLimiter interface {
	Acquire()
	Release()
}

// Loader is responsible for loading the data
type Loader interface {
	Load(ID int) ([]byte, error)
}

================================================
FILE: ch11/01_di_induced_damage/03_long_param/01_long_param.go
================================================
package long_param

import (
	"net/http"
	"time"
)

// NewMyHandler builds a MyHandler. Logging and instrumentation are no longer
// separate parameters; they are reached through config (see Config).
func NewMyHandler(config Config, parser Parser, formatter Formatter, limiter RateLimiter, loader Loader) *MyHandler {
	return &MyHandler{
		config:    config,
		parser:    parser,
		formatter: formatter,
		limiter:   limiter,
		loader:    loader,
	}
}

// MyHandler does something fantastic
type MyHandler struct {
	config    Config
	parser    Parser
	formatter Formatter
	limiter   RateLimiter
	loader    Loader
}

// ServeHTTP extracts an ID from the request, loads the matching data and
// writes it with the formatter, responding 500 if any step fails.
// NOTE(review): the config and limiter fields are not used here.
func (m *MyHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
	ID, err := m.parser.Extract(request)
	if err != nil {
		response.WriteHeader(http.StatusInternalServerError)
		return
	}

	data, err := m.loader.Load(ID)
	if err != nil {
		response.WriteHeader(http.StatusInternalServerError)
		return
	}

	err = m.formatter.Format(response, data)
	if err != nil {
		// NOTE(review): if Format already wrote to response, this status
		// change has no effect on what the client receives
		response.WriteHeader(http.StatusInternalServerError)
		return
	}
}

// Config combines environmental concerns like logging and instrumentation with any other config
type Config interface {
	Logger() Logger
	Instrumentation() Instrumentation
}

// Logger logs stuff
type Logger interface {
	Error(message string, args ...interface{})
	Warn(message string, args ...interface{})
	Info(message string, args ...interface{})
	Debug(message string, args ...interface{})
}

// Instrumentation records the performances and events
type Instrumentation interface {
	Count(key string, value int)
	Duration(key string, start time.Time)
}

// Parser will extract details from the request
type Parser interface {
	Extract(req *http.Request) (int, error)
}

// Formatter will build the output
type Formatter interface {
	Format(resp http.ResponseWriter, data []byte) error
}

// RateLimiter limits how many concurrent requests we can make or process
type RateLimiter interface {
	Acquire()
	Release()
}

// Loader is responsible for loading the data
type Loader interface {
	Load(ID int) ([]byte, error)
}

================================================
FILE: ch11/01_di_induced_damage/04_long_param/01_long_param.go
================================================
package long_param

import (
	"net/http"
	"time"
)

// NewFancyFormatHandler builds a FancyFormatHandler whose embedded MyHandler
// always uses a FancyFormatter, so callers no longer supply a Formatter.
func NewFancyFormatHandler(config Config, parser Parser, limiter RateLimiter, loader Loader) *FancyFormatHandler {
	return &FancyFormatHandler{
		&MyHandler{
			config:    config,
			formatter: &FancyFormatter{},
			parser:    parser,
			limiter:   limiter,
			loader:    loader,
		},
	}
}

// FancyFormatHandler does something fancy.
// It embeds *MyHandler, so MyHandler's ServeHTTP is promoted to it.
type FancyFormatHandler struct {
	*MyHandler
}

// MyHandler does something fantastic
type MyHandler struct {
	config    Config
	parser    Parser
	formatter Formatter
	limiter   RateLimiter
	loader    Loader
}

// ServeHTTP extracts an ID from the request, loads the matching data and
// writes it with the formatter, responding 500 if any step fails.
// NOTE(review): the config and limiter fields are not used here.
func (m *MyHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
	ID, err := m.parser.Extract(request)
	if err != nil {
		response.WriteHeader(http.StatusInternalServerError)
		return
	}

	data, err := m.loader.Load(ID)
	if err != nil {
		response.WriteHeader(http.StatusInternalServerError)
		return
	}

	err = m.formatter.Format(response, data)
	if err != nil {
		response.WriteHeader(http.StatusInternalServerError)
		return
	}
}

// Config combines environmental concerns like logging and instrumentation with any other config
type Config interface {
Logger() Logger
	Instrumentation() Instrumentation
}

// Logger logs stuff
type Logger interface {
	Error(message string, args ...interface{})
	Warn(message string, args ...interface{})
	Info(message string, args ...interface{})
	Debug(message string, args ...interface{})
}

// Instrumentation records the performances and events
type Instrumentation interface {
	Count(key string, value int)
	Duration(key string, start time.Time)
}

// Parser will extract details from the request
type Parser interface {
	Extract(req *http.Request) (int, error)
}

// Formatter will build the output
type Formatter interface {
	Format(resp http.ResponseWriter, data []byte) error
}

// FancyFormatter implements Formatter by always writing the fixed body
// `something fancy!`; the data passed to Format is ignored.
type FancyFormatter struct{}

// Format writes the fixed fancy body to response and returns any write error.
func (f *FancyFormatter) Format(response http.ResponseWriter, data []byte) error {
	// does something fancy with the data (in this example: ignores it)
	_, err := response.Write([]byte(`something fancy!`))
	return err
}

// RateLimiter limits how many concurrent requests we can make or process
type RateLimiter interface {
	Acquire()
	Release()
}

// Loader is responsible for loading the data
type Loader interface {
	Load(ID int) ([]byte, error)
}

================================================
FILE: ch11/01_di_induced_damage/04_long_param/01_long_param_test.go
================================================
package long_param

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestNewFancyFormatHandler checks that a handler built by
// NewFancyFormatHandler with succeeding stub dependencies responds 200 with
// the FancyFormatter's output.
func TestNewFancyFormatHandler(t *testing.T) {
	// inputs
	config := &stubConfig{}
	parser := &stubParser{}
	limiter := &stubRateLimiter{}
	loader := &stubLoader{}

	// create the handler
	fancyHandler := NewFancyFormatHandler(config, parser, limiter, loader)

	// call with fake HTTP request
	response := httptest.NewRecorder()
	request := &http.Request{}
	fancyHandler.ServeHTTP(response, request)

	// validate result
	assert.Equal(t, http.StatusOK, response.Code)
	assert.Equal(t, "something fancy!", response.Body.String())
}

// define some mock implementations so that our test can run
type
stubConfig struct{} func (s *stubConfig) Logger() Logger { return &stubLogger{} } func (s *stubConfig) Instrumentation() Instrumentation { return &stubInstrumentation{} } type stubLogger struct{} func (s *stubLogger) Error(message string, args ...interface{}) { // do nothing } func (s *stubLogger) Warn(message string, args ...interface{}) { // do nothing } func (s *stubLogger) Info(message string, args ...interface{}) { // do nothing } func (s *stubLogger) Debug(message string, args ...interface{}) { // do nothing } type stubInstrumentation struct{} func (s *stubInstrumentation) Count(key string, value int) { // do nothing } func (s *stubInstrumentation) Duration(key string, start time.Time) { // do nothing } type stubParser struct{} func (s *stubParser) Extract(req *http.Request) (int, error) { return 1, nil } type stubRateLimiter struct{} func (s *stubRateLimiter) Acquire() { // do nothing } func (s *stubRateLimiter) Release() { // do nothing } type stubLoader struct{} func (s *stubLoader) Load(ID int) ([]byte, error) { return []byte(`some data`), nil } ================================================ FILE: ch11/01_di_induced_damage/05_inject_sql/01_interface.go ================================================ package inject_vs_config import ( "context" "database/sql" ) type Connection interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } ================================================ FILE: ch11/01_di_induced_damage/06_inject_sql/01_interface.go ================================================ package data import ( "context" ) type Database interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) Row QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) ExecContext(ctx context.Context, query 
string, args ...interface{}) (Result, error) } type Row interface { Scan(dest ...interface{}) error } type Rows interface { Scan(dest ...interface{}) error Close() error Next() bool } type Result interface { LastInsertId() (int64, error) RowsAffected() (int64, error) } ================================================ FILE: ch11/01_di_induced_damage/06_inject_sql/02_implementation.go ================================================ package data import ( "context" "database/sql" ) type DatabaseImpl struct { db *sql.DB } func (c *DatabaseImpl) QueryRowContext(ctx context.Context, query string, args ...interface{}) Row { return c.db.QueryRowContext(ctx, query, args...) } func (c *DatabaseImpl) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { return c.db.QueryContext(ctx, query, args...) } func (c *DatabaseImpl) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { return c.db.ExecContext(ctx, query, args...) } type RowImpl struct { row *sql.Row } func (r *RowImpl) Scan(dest ...interface{}) error { return r.row.Scan(dest...) } type RowsImpl struct { rows *sql.Rows } func (r RowsImpl) Scan(dest ...interface{}) error { return r.rows.Scan(dest...) 
} func (r RowsImpl) Close() error { return r.rows.Close() } func (r RowsImpl) Next() bool { return r.rows.Next() } type ResultImpl struct { result sql.Result } func (r *ResultImpl) LastInsertId() (int64, error) { return r.result.LastInsertId() } func (r *ResultImpl) RowsAffected() (int64, error) { return r.result.RowsAffected() } ================================================ FILE: ch11/01_di_induced_damage/06_inject_sql/02_implementation_test.go ================================================ package data import ( "testing" "github.com/stretchr/testify/assert" ) func TestImplements(t *testing.T) { assert.Implements(t, (*Database)(nil), &DatabaseImpl{}) assert.Implements(t, (*Row)(nil), &RowImpl{}) assert.Implements(t, (*Rows)(nil), &RowsImpl{}) assert.Implements(t, (*Result)(nil), &ResultImpl{}) } ================================================ FILE: ch11/01_di_induced_damage/06_inject_sql/dao.go ================================================ package data import ( "context" "database/sql" "fmt" "os" "time" ) // NewDAO will initialize the database connection pool (if not already done) and return a data access object which // can be used to interact with the database func NewDAO(cfg Config) *DAO { // initialize the db connection pool _, _ = getDB(cfg) return &DAO{ cfg: cfg, } } // DAO is a data access object that provides an abstraction over our database interactions. type DAO struct { cfg Config } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) { db, err := getDB(d.cfg) if err != nil { fmt.Fprintf(os.Stderr, "failed to get DB connection. 
err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select row := db.QueryRowContext(subCtx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { fmt.Fprintf(os.Stderr, "failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } fmt.Fprintf(os.Stderr, "failed to convert query result. err: %s", err) return nil, err } return out, nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) LoadAll(ctx context.Context) ([]*Person, error) { db, err := getDB(d.cfg) if err != nil { fmt.Fprintf(os.Stderr, "failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select rows, err := db.QueryContext(subCtx, sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { fmt.Fprintf(os.Stderr, "failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { fmt.Fprintf(os.Stderr, "no people found in the database.") return nil, ErrNotFound } return out, nil } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func (d *DAO) Save(ctx context.Context, in *Person) (int, error) { db, err := getDB(d.cfg) if err != nil { fmt.Fprintf(os.Stderr, "failed to get DB connection. 
err: %s", err) return defaultPersonID, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB insert result, err := db.ExecContext(subCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { fmt.Fprintf(os.Stderr, "failed to save person into DB. err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { fmt.Fprintf(os.Stderr, "failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } ================================================ FILE: ch11/01_di_induced_damage/06_inject_sql/data.go ================================================ package data import ( "database/sql" "errors" _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) // Config is the configuration for the data package type Config interface { // DataDSN returns the data source name DataDSN() string } var getDB = func(cfg Config) (Database, error) { if db == nil { var err error db, err = sql.Open("mysql", cfg.DataDSN()) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return &DatabaseImpl{db: db}, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch11/01_di_induced_damage/06_inject_sql/data_test.go ================================================ package data import ( "context" "errors" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) 
// monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patching starts here defer func(original func(_ Config) (Database, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func(_ Config) (Database, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { // define context and therefore test 
timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) results, err := dao.LoadAll(ctx) // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). 
AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) result, err := dao.Load(ctx, 2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } type testConfig struct{} // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } ================================================ FILE: ch11/01_di_induced_damage/07_needless_indirection/example_test.go ================================================ package needless_indirection import ( "io/ioutil" "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestExample(t *testing.T) { router := http.NewServeMux() router.HandleFunc("/health", func(resp http.ResponseWriter, req *http.Request) { _, _ = resp.Write([]byte(`OK`)) }) // start a server address := ":8080" go func() { _ = http.ListenAndServe(address, router) }() // call the server resp, err := http.Get("http://:8080/health") require.NoError(t, err) // validate the response responseBody, err := ioutil.ReadAll(resp.Body) 
assert.Equal(t, []byte(`OK`), responseBody) } ================================================ FILE: ch11/01_di_induced_damage/08_needless_indirection/01_mux.go ================================================ package needless_indirection import ( "net/http" ) //go:generate mockery -name=MyMux -case underscore -testonly -inpkg -note @generated type MyMux interface { Handle(pattern string, handler http.Handler) Handler(req *http.Request) (handler http.Handler, pattern string) ServeHTTP(resp http.ResponseWriter, req *http.Request) } // build HTTP handler routing func buildRouter(mux MyMux) { mux.Handle("/get", &getEndpoint{}) mux.Handle("/list", &listEndpoint{}) mux.Handle("/save", &saveEndpoint{}) } type getEndpoint struct{} func (*getEndpoint) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { // not implemented } type listEndpoint struct{} func (*listEndpoint) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { // not implemented } type saveEndpoint struct{} func (*saveEndpoint) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { // not implemented } ================================================ FILE: ch11/01_di_induced_damage/08_needless_indirection/01_mux_test.go ================================================ package needless_indirection import ( "testing" "github.com/stretchr/testify/assert" ) func TestBuildRouter(t *testing.T) { // build mock mockRouter := &MockMyMux{} mockRouter.On("Handle", "/get", &getEndpoint{}).Once() mockRouter.On("Handle", "/list", &listEndpoint{}).Once() mockRouter.On("Handle", "/save", &saveEndpoint{}).Once() // call function buildRouter(mockRouter) // assert expectations assert.True(t, mockRouter.AssertExpectations(t)) } ================================================ FILE: ch11/01_di_induced_damage/08_needless_indirection/mock_my_mux_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package needless_indirection import ( "net/http" "github.com/stretchr/testify/mock" ) // 
MockMyMux is an autogenerated mock type for the MyMux type type MockMyMux struct { mock.Mock } // Handle provides a mock function with given fields: pattern, handler func (_m *MockMyMux) Handle(pattern string, handler http.Handler) { _m.Called(pattern, handler) } // Handler provides a mock function with given fields: req func (_m *MockMyMux) Handler(req *http.Request) (http.Handler, string) { ret := _m.Called(req) var r0 http.Handler if rf, ok := ret.Get(0).(func(*http.Request) http.Handler); ok { r0 = rf(req) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(http.Handler) } } var r1 string if rf, ok := ret.Get(1).(func(*http.Request) string); ok { r1 = rf(req) } else { r1 = ret.Get(1).(string) } return r0, r1 } // ServeHTTP provides a mock function with given fields: resp, req func (_m *MockMyMux) ServeHTTP(resp http.ResponseWriter, req *http.Request) { _m.Called(resp, req) } ================================================ FILE: ch11/01_di_induced_damage/09_needless_indirection/01_mux.go ================================================ package needless_indirection import ( "net/http" ) // build HTTP handler routing func buildRouter(mux *http.ServeMux) { mux.Handle("/get", &getEndpoint{}) mux.Handle("/list", &listEndpoint{}) mux.Handle("/save", &saveEndpoint{}) } type getEndpoint struct{} func (*getEndpoint) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { // not implemented } type listEndpoint struct{} func (*listEndpoint) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { // not implemented } type saveEndpoint struct{} func (*saveEndpoint) ServeHTTP(_ http.ResponseWriter, _ *http.Request) { // not implemented } ================================================ FILE: ch11/01_di_induced_damage/09_needless_indirection/01_mux_test.go ================================================ package needless_indirection import ( "net/http" "testing" "github.com/stretchr/testify/assert" ) func TestBuildRouter(t *testing.T) { router := http.NewServeMux() // call function 
buildRouter(router) // assertions assert.IsType(t, &getEndpoint{}, extractHandler(router, "/get")) assert.IsType(t, &listEndpoint{}, extractHandler(router, "/list")) assert.IsType(t, &saveEndpoint{}, extractHandler(router, "/save")) } func extractHandler(router *http.ServeMux, path string) http.Handler { req, _ := http.NewRequest("GET", path, nil) handler, _ := router.Handler(req) return handler } ================================================ FILE: ch11/01_di_induced_damage/10_needless_indirection/01_mux_e2e.go ================================================ package needless_indirection import ( "net/http" ) // build HTTP handler routing func buildRouter(mux *http.ServeMux) { mux.Handle("/get", &getEndpoint{}) mux.Handle("/list", &listEndpoint{}) mux.Handle("/save", &saveEndpoint{}) } type getEndpoint struct{} func (*getEndpoint) ServeHTTP(resp http.ResponseWriter, _ *http.Request) { _, _ = resp.Write([]byte(`Hi from Get!`)) } type listEndpoint struct{} func (*listEndpoint) ServeHTTP(resp http.ResponseWriter, _ *http.Request) { _, _ = resp.Write([]byte(`Hi from List!`)) } type saveEndpoint struct{} func (*saveEndpoint) ServeHTTP(resp http.ResponseWriter, _ *http.Request) { _, _ = resp.Write([]byte(`Hi from Save!`)) } ================================================ FILE: ch11/01_di_induced_damage/10_needless_indirection/01_mux_e2e_test.go ================================================ package needless_indirection import ( "io/ioutil" "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestBuildRouter(t *testing.T) { router := http.NewServeMux() // call function buildRouter(router) // start a server address := ":8080" go func() { _ = http.ListenAndServe(address, router) }() // call endpoints responseBody := doGet(t, address+"/get") assert.Equal(t, `Hi from Get!`, responseBody) responseBody = doGet(t, address+"/list") assert.Equal(t, `Hi from List!`, responseBody) responseBody = doGet(t, address+"/save") 
assert.Equal(t, `Hi from Save!`, responseBody) } func doGet(t *testing.T, address string) string { resp, err := http.Get("http://" + address) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() return string(body) } ================================================ FILE: ch11/01_di_induced_damage/11_service_locator/01_service_locator.go ================================================ package service_locator func NewServiceLocator() *ServiceLocator { return &ServiceLocator{ deps: map[string]interface{}{}, } } type ServiceLocator struct { deps map[string]interface{} } // Store or map a dependency to a key func (s *ServiceLocator) Store(key string, dep interface{}) { s.deps[key] = dep } // Retrieve a dependency by key func (s *ServiceLocator) Get(key string) interface{} { return s.deps[key] } ================================================ FILE: ch11/01_di_induced_damage/11_service_locator/02_usage.go ================================================ package service_locator func Example() { locator := buildServiceLocator() useServiceLocator(locator) } func buildServiceLocator() *ServiceLocator { // build a service locator locator := NewServiceLocator() // load the dependency mappings locator.Store("logger", &myLogger{}) locator.Store("converter", &myConverter{}) return locator } func useServiceLocator(locator *ServiceLocator) { // use the locators to get the logger logger := locator.Get("logger").(Logger) // use the logger logger.Info("Hello World!") } func useServiceLocatorExtended(locator *ServiceLocator) { // use the locators to get the logger loggerRetrieved := locator.Get("logger") if loggerRetrieved == nil { return } logger, ok := loggerRetrieved.(Logger) if !ok { return } // use the logger logger.Info("Hello World!") } type Logger interface { Info(message string, args ...interface{}) } type myLogger struct{} func (m *myLogger) Info(message string, args ...interface{}) { // not implemented } type 
Converter interface { Convert(int float64) (float64, error) } type myConverter struct{} func (m *myConverter) Convert(in float64) (float64, error) { // not implemented return 0, nil } ================================================ FILE: ch11/02_premature_future/get.go ================================================ package rest import ( "errors" "fmt" "io" "net/http" "strconv" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration type GetModel interface { Do(ID int) (*Person, error) } // GetConfig is the config for the Get Handler type GetConfig interface { Logger() Logger } // Formatter will convert the supplied object to bytes type Formatter interface { Marshal(interface{}) ([]byte, error) } // NewGetHandler is the constructor for GetHandler func NewGetHandler(cfg GetConfig, model GetModel, formatter Formatter) *GetHandler { return &GetHandler{ cfg: cfg, getter: model, formatter: formatter, } } // GetHandler is the HTTP handler for the "Get Person" endpoint type GetHandler struct { cfg GetConfig getter GetModel formatter Formatter } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.buildOutput(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we 
extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") h.cfg.Logger().Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. err: %s", err) h.cfg.Logger().Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person func (h *GetHandler) buildOutput(writer io.Writer, person *Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // build output payload payload, err := h.formatter.Marshal(output) if err != nil { return err } // write payload to response and return _, err = writer.Write(payload) return err } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } type Person struct { ID int FullName string Phone string Currency string Price float64 } type Logger interface { Debug(message string, args ...interface{}) Info(message string, args ...interface{}) Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } ================================================ FILE: ch11/03_mocking_http_requests/converter.go ================================================ package mocking_http_reques import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "net/http" "time" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // Config is the config for Converter type Config interface { Logger() Logger ExchangeBaseURL() string ExchangeAPIKey() string } // 
NewConverter creates and initializes the converter func NewConverter(cfg Config, requester Requester) *Converter { return &Converter{ cfg: cfg, requester: requester, } } // Converter will convert the base price to the currency supplied type Converter struct { cfg Config requester Requester } // Exchange will perform the conversion func (c *Converter) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { // load rate from the external API response, err := c.loadRateFromServer(ctx, currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, c.cfg.ExchangeBaseURL(), c.cfg.ExchangeAPIKey(), currency) // perform request response, err := c.requester.doRequest(ctx, url) if err != nil { c.logger().Warn("[exchange] failed to load. 
err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) c.logger().Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) c.logger().Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { c.logger().Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { c.logger().Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } func (c *Converter) logger() Logger { return c.cfg.Logger() } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } // Requester builds and sending HTTP requests //go:generate mockery -name=Requester -case underscore -testonly -inpkg -note @generated type Requester interface { doRequest(ctx context.Context, url string) (*http.Response, error) } // Requesterer is the default implementation of Requester type Requesterer struct { } func (r *Requesterer) doRequest(ctx context.Context, url string) (*http.Response, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } // set latency budget for the upstream call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // replace the default context with our custom one req = req.WithContext(subCtx) // perform the HTTP request return http.DefaultClient.Do(req) } type Logger interface { Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } type stubLogger struct{} func (l *stubLogger) Warn(message string, args ...interface{}) { // do nothing } func (l *stubLogger) Error(message string, args ...interface{}) { // do nothing } ================================================ FILE: ch11/03_mocking_http_requests/converter_test.go ================================================ package mocking_http_reques import ( "context" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestExchange_invalidResponse(t *testing.T) { // build response response := httptest.NewRecorder() _, err := response.WriteString(`invalid payload`) require.NoError(t, err) // configure mock mockRequester := &mockRequester{} mockRequester.On("doRequest", mock.Anything, mock.Anything).Return(response.Result(), nil).Once() // inputs ctx, cancel 
:= context.WithTimeout(context.Background(), 1*time.Second) defer cancel() basePrice := 12.34 currency := "AUD" // perform call converter := &Converter{ requester: mockRequester, cfg: &testConfig{}, } result, resultErr := converter.Exchange(ctx, basePrice, currency) // validate response assert.Equal(t, float64(0), result) assert.Error(t, resultErr) assert.True(t, mockRequester.AssertExpectations(t)) } // stub config that returns known values type testConfig struct { } func (t *testConfig) Logger() Logger { return &stubLogger{} } func (t *testConfig) ExchangeBaseURL() string { return "http://www.example.com" } func (t *testConfig) ExchangeAPIKey() string { return "foo" } ================================================ FILE: ch11/03_mocking_http_requests/mock_requester_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package mocking_http_reques import ( "context" "net/http" "github.com/stretchr/testify/mock" ) // mockRequester is an autogenerated mock type for the requester type type mockRequester struct { mock.Mock } // doRequest provides a mock function with given fields: ctx, url func (_m *mockRequester) doRequest(ctx context.Context, url string) (*http.Response, error) { ret := _m.Called(ctx, url) var r0 *http.Response if rf, ok := ret.Get(0).(func(context.Context, string) *http.Response); ok { r0 = rf(ctx, url) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*http.Response) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, url) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch11/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "fmt" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const 
DefaultEnvVar = "ACME_CONFIG" // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string // environmental dependencies logger logging.Logger } // Logger returns a reference to the singleton logger func (c *Config) Logger() logging.Logger { if c.logger == nil { c.logger = &logging.LoggerStdOut{} } return c.logger } // RegistrationBasePrice returns the base price for registrations func (c *Config) RegistrationBasePrice() float64 { return c.BasePrice } // DataDSN returns the DSN func (c *Config) DataDSN() string { return c.DSN } // ExchangeBaseURL returns the Base URL from which we can load exchange rates func (c *Config) ExchangeBaseURL() string { return c.ExchangeRateBaseURL } // ExchangeAPIKey returns the DSN func (c *Config) ExchangeAPIKey() string { return c.ExchangeRateAPIKey } // BindAddress returns the host and port this service should bind to func (c *Config) BindAddress() string { return c.Address } // Load returns the config loaded from environment func Load() (*Config, error) { filename, found := os.LookupEnv(DefaultEnvVar) if !found { err := fmt.Errorf("failed to locate file specified by %s", DefaultEnvVar) fmt.Fprintf(os.Stderr, err.Error()) return nil, err } cfg, err := load(filename) if err != nil { fmt.Fprintf(os.Stderr, "failed to load config with err %s", err) return nil, err } return cfg, nil } func load(filename string) (*Config, error) { out := &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { fmt.Fprintf(os.Stderr, "failed to read config file. 
err: %s", err) return nil, err } err = json.Unmarshal(bytes, out) if err != nil { fmt.Fprintf(os.Stderr, "failed to parse config file. err : %s", err) return nil, err } return out, nil } ================================================ FILE: ch11/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: nil, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { result, resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, result, scenario.desc) }) } } ================================================ FILE: ch11/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // Logger is our standard interface type Logger interface { Debug(message string, args ...interface{}) Info(message string, args ...interface{}) Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) 
} // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) } ================================================ FILE: ch11/acme/internal/modules/data/dao.go ================================================ package data import ( "context" "database/sql" "time" ) // NewDAO will initialize the database connection pool (if not already done) and return a data access object which // can be used to interact with the database func NewDAO(cfg Config) *DAO { // initialize the db connection pool _, _ = getDB(cfg) return &DAO{ cfg: cfg, } } // DAO is a data access object that provides an abstraction over our database interactions. type DAO struct { cfg Config // Tracker is an optional query timer Tracker QueryTracker } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) { // track processing time defer d.getTracker().Track("Load", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select row := db.QueryRowContext(subCtx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { d.cfg.Logger().Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } d.cfg.Logger().Error("failed to convert query result. 
err: %s", err) return nil, err } return out, nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) LoadAll(ctx context.Context) ([]*Person, error) { // track processing time defer d.getTracker().Track("LoadAll", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select rows, err := db.QueryContext(subCtx, sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { d.cfg.Logger().Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { d.cfg.Logger().Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func (d *DAO) Save(ctx context.Context, in *Person) (int, error) { // track processing time defer d.getTracker().Track("Save", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB insert result, err := db.ExecContext(subCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { d.cfg.Logger().Error("failed to save person into DB. 
err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { d.cfg.Logger().Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } func (d *DAO) getTracker() QueryTracker { if d.Tracker == nil { d.Tracker = &noopTracker{} } return d.Tracker } ================================================ FILE: ch11/acme/internal/modules/data/data.go ================================================ package data import ( "database/sql" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) // Config is the configuration for the data package type Config interface { // Logger returns a reference to the logger Logger() logging.Logger // DataDSN returns the data source name DataDSN() string } var getDB = func(cfg Config) (*sql.DB, error) { if db == nil { var err error db, err = sql.Open("mysql", cfg.DataDSN()) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch11/acme/internal/modules/data/data_test.go ================================================ package data import ( "context" "database/sql" "errors" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() 
require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patching starts here defer func(original func(_ Config) (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func(_ Config) (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, 
in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) results, err := dao.LoadAll(ctx) // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := 
convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) result, err := dao.Load(ctx, 2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } type testConfig struct{} // Logger implements Config func (t *testConfig) Logger() logging.Logger { return logging.LoggerStdOut{} } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } ================================================ FILE: ch11/acme/internal/modules/data/tracker.go ================================================ package data import ( "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" ) // QueryTracker is an interface to track query timing type QueryTracker interface { // Track will record/out the time a query took by calculating time.Now().Sub(start) Track(key string, start time.Time) } // NO-OP implementation of QueryTracker type 
noopTracker struct{} // Track implements QueryTracker func (_ *noopTracker) Track(_ string, _ time.Time) { // intentionally does nothing } // NewLogTracker returns a Tracker that outputs tracking data to log func NewLogTracker(logger logging.Logger) *LogTracker { return &LogTracker{ logger: logger, } } // LogTracker implements QueryTracker and outputs to the supplied logger type LogTracker struct { logger logging.Logger } // Track implements QueryTracker func (l *LogTracker) Track(key string, start time.Time) { l.logger.Info("[%s] Timing: %s\n", key, time.Now().Sub(start).String()) } ================================================ FILE: ch11/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // NewConverter creates and initializes the converter func NewConverter(cfg Config) *Converter { return &Converter{ cfg: cfg, } } // Config is the config for Converter type Config interface { Logger() logging.Logger ExchangeBaseURL() string ExchangeAPIKey() string } // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct { cfg Config } // Exchange will perform the conversion func (c *Converter) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { // load rate from the external API response, err := c.loadRateFromServer(ctx, currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 
decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, c.cfg.ExchangeBaseURL(), c.cfg.ExchangeAPIKey(), currency) // perform request req, err := http.NewRequest("GET", url, nil) if err != nil { c.logger().Warn("[exchange] failed to create request. err: %s", err) return nil, err } // set latency budget for the upstream call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // replace the default context with our custom one req = req.WithContext(subCtx) // perform the HTTP request response, err := http.DefaultClient.Do(req) if err != nil { c.logger().Warn("[exchange] failed to load. err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) c.logger().Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) c.logger().Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { c.logger().Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { c.logger().Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } func (c *Converter) logger() logging.Logger { return c.cfg.Logger() } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch11/acme/internal/modules/exchange/converter_ext_bounday_test.go ================================================ // +build external package exchange import ( "context" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestExternalBoundaryTest(t *testing.T) { // define the config cfg, err := config.Load() require.NoError(t, err) // create a converter to test converter := NewConverter(cfg) // fetch from the server response, err := converter.loadRateFromServer(context.Background(), "AUD") require.NotNil(t, response) require.NoError(t, err) // parse the response resultRate, err := converter.extractRate(response, "AUD") require.NoError(t, err) // validate the result assert.True(t, resultRate > 0) } ================================================ FILE: ch11/acme/internal/modules/exchange/converter_int_bounday_test.go ================================================ package exchange import ( "context" "net/http" "net/http/httptest" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/stretchr/testify/assert" ) func TestInternalBoundaryTest(t *testing.T) { // start our test server server := httptest.NewServer(&happyExchangeRateService{}) defer server.Close() // define the config cfg := &testConfig{ baseURL: server.URL, apiKey: "", } // create a converter to test converter := NewConverter(cfg) resultRate, resultErr := converter.Exchange(context.Background(), 100.00, "AUD") // validate the result assert.Equal(t, 101.01, resultRate) assert.NoError(t, resultErr) } 
type happyExchangeRateService struct{} // ServeHTTP implements http.Handler func (*happyExchangeRateService) ServeHTTP(response http.ResponseWriter, request *http.Request) { payload := []byte(` { "success":true, "historical":true, "date":"2010-11-09", "timestamp":1289347199, "source":"USD", "quotes":{ "USDAUD":0.989981 } }`) response.Write(payload) } func TestExchange_invalidResponseFromServer(t *testing.T) { // start our test server server := httptest.NewServer(http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { payload := []byte(`invalid payload`) response.Write(payload) })) defer server.Close() // inputs ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() basePrice := 12.34 currency := "AUD" cfg := &testConfig{ baseURL: server.URL, apiKey: "", } converter := NewConverter(cfg) result, resultErr := converter.Exchange(ctx, basePrice, currency) // validate response assert.Equal(t, float64(0), result) assert.Error(t, resultErr) } // test implementation of Config type testConfig struct { baseURL string apiKey string } // Logger implements Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // ExchangeBaseURL implements Config func (t *testConfig) ExchangeBaseURL() string { return t.baseURL } // ExchangeAPIKey implements Config func (t *testConfig) ExchangeAPIKey() string { return t.apiKey } ================================================ FILE: ch11/acme/internal/modules/get/get.go ================================================ package get import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // NewGetter creates and initializes a Getter func NewGetter(cfg Config) *Getter { return &Getter{ 
cfg: cfg, } } // Config is the configuration for Getter type Config interface { Logger() logging.Logger DataDSN() string } // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { cfg Config data myLoader } // Do will perform the get func (g *Getter) Do(ID int) (*Person, error) { // load person from the data layer person, err := g.getLoader().Load(context.TODO(), ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. return nil, errPersonNotFound } return nil, err } return g.convert(person), err } func (g *Getter) getLoader() myLoader { if g.data == nil { g.data = data.NewDAO(g.cfg) } return g.data } func (g *Getter) convert(in *data.Person) *Person { return &Person{ ID: in.ID, Currency: in.Currency, FullName: in.FullName, Phone: in.Phone, Price: in.Price, } } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { Load(ctx context.Context, ID int) (*data.Person, error) } // Person is a copy/sub-set of data.Person so that the relationship does not leak. 
// It also allows us to remove/hide and internal fields type Person struct { ID int FullName string Phone string Currency string Price float64 } ================================================ FILE: ch11/acme/internal/modules/get/go_test.go ================================================ package get import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockResult := &data.Person{ ID: 1234, FullName: "Doug", } mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(mockResult, nil).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, data.ErrNotFound).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, errors.New("something failed")).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.Error(t, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: 
ch11/acme/internal/modules/get/mock_my_loader_test.go
================================================
// Code generated by mockery v1.0.0
// @generated

package get

import (
	"context"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data"
	"github.com/stretchr/testify/mock"
)

// mockMyLoader is an autogenerated mock type for the myLoader type
// NOTE(review): produced by the `//go:generate mockery -name=myLoader ...` directive in
// get.go; regenerate it rather than editing by hand.
type mockMyLoader struct {
	mock.Mock
}

// Load provides a mock function with given fields: ctx, ID
// Each return value comes from the matching On("Load", ...).Return(...) call: a registered
// function is invoked with (ctx, ID); otherwise the stored value is used (a nil first
// value yields a nil *data.Person).
func (_m *mockMyLoader) Load(ctx context.Context, ID int) (*data.Person, error) {
	ret := _m.Called(ctx, ID)

	var r0 *data.Person
	if rf, ok := ret.Get(0).(func(context.Context, int) *data.Person); ok {
		r0 = rf(ctx, ID)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*data.Person)
		}
	}

	var r1 error
	if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
		r1 = rf(ctx, ID)
	} else {
		r1 = ret.Error(1)
	}

	return r0, r1
}

================================================
FILE: ch11/acme/internal/modules/list/list.go
================================================
package list

import (
	"context"
	"errors"

	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging"
	"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data"
)

var (
	// error thrown when there are no people in the database
	errPeopleNotFound = errors.New("no people found")
)

// NewLister creates and initializes a Lister
func NewLister(cfg Config) *Lister {
	return &Lister{
		cfg: cfg,
	}
}

// Config is the config for Lister.
// It is also handed to data.NewDAO when the default loader is created.
type Config interface {
	// Logger returns the logger used by this package and the data layer
	Logger() logging.Logger
	// DataDSN returns the data source name (consumed by the data package)
	DataDSN() string
}

// Lister will attempt to load all people in the database.
// It can return an error caused by the data layer type Lister struct { cfg Config data myLoader } // Exchange will load the people from the data layer func (l *Lister) Do() ([]*Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return l.convert(people), nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := l.getLoader().LoadAll(context.TODO()) if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. return nil, errPeopleNotFound } return nil, err } return people, nil } func (l *Lister) getLoader() myLoader { if l.data == nil { l.data = data.NewDAO(l.cfg) // temporarily add a log tracker l.data.(*data.DAO).Tracker = data.NewLogTracker(l.cfg.Logger()) } return l.data } func (l *Lister) convert(in []*data.Person) []*Person { out := make([]*Person, len(in)) for index, thisRecord := range in { out[index] = &Person{ ID: thisRecord.ID, FullName: thisRecord.FullName, Phone: thisRecord.Phone, } } return out } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { LoadAll(ctx context.Context) ([]*data.Person, error) } // Person is a copy/sub-set of data.Person so that the relationship does not leak. 
// It also allows us to remove/hide and internal fields type Person struct { ID int FullName string Phone string } ================================================ FILE: ch11/acme/internal/modules/list/list_test.go ================================================ package list import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // configure the mock loader mockResult := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(mockResult, nil).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func TestLister_Do_noResults(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, data.ErrNotFound).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func TestLister_Do_error(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, errors.New("something failed")).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: ch11/acme/internal/modules/list/mock_my_loader_test.go ================================================ // Code 
generated by mockery v1.0.0 // @generated package list import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMyLoader is an autogenerated mock type for the myLoader type type mockMyLoader struct { mock.Mock } // LoadAll provides a mock function with given fields: ctx func (_m *mockMyLoader) LoadAll(ctx context.Context) ([]*data.Person, error) { ret := _m.Called(ctx) var r0 []*data.Person if rf, ok := ret.Get(0).(func(context.Context) []*data.Person); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context) error); ok { r1 = rf(ctx) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch11/acme/internal/modules/register/mock_exchanger_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package register import ( "context" "github.com/stretchr/testify/mock" ) // MockExchanger is an autogenerated mock type for the Exchanger type type MockExchanger struct { mock.Mock } // Exchange provides a mock function with given fields: ctx, basePrice, currency func (_m *MockExchanger) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { ret := _m.Called(ctx, basePrice, currency) var r0 float64 if rf, ok := ret.Get(0).(func(context.Context, float64, string) float64); ok { r0 = rf(ctx, basePrice, currency) } else { r0 = ret.Get(0).(float64) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, float64, string) error); ok { r1 = rf(ctx, basePrice, currency) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch11/acme/internal/modules/register/mock_my_saver_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package register import ( "context" 
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMySaver is an autogenerated mock type for the mySaver type type mockMySaver struct { mock.Mock } // Save provides a mock function with given fields: ctx, in func (_m *mockMySaver) Save(ctx context.Context, in *data.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok := ret.Get(0).(func(context.Context, *data.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch11/acme/internal/modules/register/register.go ================================================ package register import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/data" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // NewRegisterer creates and initializes a Registerer func NewRegisterer(cfg Config, exchanger Exchanger) *Registerer { return &Registerer{ cfg: cfg, exchanger: exchanger, } } // Exchanger will convert from one currency to another //go:generate mockery -name=Exchanger -case underscore -testonly -inpkg -note @generated type Exchanger interface { // Exchange 
will perform the conversion Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) } // Config is the configuration for the Registerer type Config interface { Logger() logging.Logger RegistrationBasePrice() float64 DataDSN() string } // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. type Registerer struct { cfg Config exchanger Exchanger data mySaver } // Do is API for this struct func (r *Registerer) Do(ctx context.Context, in *Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { r.logger().Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(ctx, in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(ctx, r.convert(in), price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) { price, err := r.exchanger.Exchange(ctx, r.cfg.RegistrationBasePrice(), currency) if err != nil { r.logger().Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(ctx context.Context, in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return r.getSaver().Save(ctx, person) } func (r *Registerer) getSaver() mySaver { if r.data == nil { r.data = data.NewDAO(r.cfg) } return r.data } func (r *Registerer) logger() logging.Logger { return r.cfg.Logger() } func (r *Registerer) convert(in *Person) *data.Person { return &data.Person{ ID: in.ID, Currency: in.Currency, FullName: in.FullName, Phone: in.Phone, Price: in.Price, } } //go:generate mockery -name=mySaver -case underscore -testonly -inpkg -note @generated type mySaver interface { Save(ctx context.Context, in *data.Person) (int, error) } // Person is a copy/sub-set of data.Person so that the relationship does not leak. // It also allows us to remove/hide and internal fields type Person struct { ID int FullName string Phone string Currency string Price float64 } ================================================ FILE: ch11/acme/internal/modules/register/register_test.go ================================================ package register import ( "context" "errors" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // configure the mock saver mockResult := 888 mockSaver := &mockMySaver{} mockSaver.On("Save", mock.Anything, mock.Anything).Return(mockResult, nil).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: 
&stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) assert.True(t, mockSaver.AssertExpectations(t)) } func TestRegisterer_Do_error(t *testing.T) { // configure the mock saver mockSaver := &mockMySaver{} mockSaver.On("Save", mock.Anything, mock.Anything).Return(0, errors.New("something failed")).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) assert.True(t, mockSaver.AssertExpectations(t)) } func TestRegisterer_Do_exchangeError(t *testing.T) { // configure the mocks mockSaver := &mockMySaver{} mockExchanger := &MockExchanger{} mockExchanger. On("Exchange", mock.Anything, mock.Anything, mock.Anything). Return(0.0, errors.New("failed to load conversion")). 
Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: mockExchanger, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) assert.True(t, mockSaver.AssertExpectations(t)) assert.True(t, mockExchanger.AssertExpectations(t)) } // Stub implementation of Config type testConfig struct{} // Logger implement Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // RegistrationBasePrice implement Config func (t *testConfig) RegistrationBasePrice() float64 { return 12.34 } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } type stubExchanger struct{} // Exchange implements Exchanger func (s stubExchanger) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { return 12.34, nil } ================================================ FILE: ch11/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/get" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration //go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated type GetModel interface { Do(ID int) (*get.Person, error) } // GetConfig is the config for the Get Handler type GetConfig interface { Logger() logging.Logger } // NewGetHandler is the constructor for GetHandler func NewGetHandler(cfg 
GetConfig, model GetModel) *GetHandler { return &GetHandler{ cfg: cfg, getter: model, } } // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. type GetHandler struct { cfg GetConfig getter GetModel } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") h.cfg.Logger().Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) h.cfg.Logger().Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *get.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch11/acme/internal/rest/get_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/get" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockGetModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { output := &get.Person{ ID: 1, FullName: "John", Phone: "0123456789", Currency: "USD", Price: 100, } mockGetModel := &MockGetModel{} mockGetModel.On("Do", mock.Anything).Return(output, nil).Once() return mockGetModel }, expectedStatus: http.StatusOK, expectedPayload: 
`{"id":1,"name":"John","phone":"0123456789","currency":"USD","price":100}` + "\n", }, { desc: "bad input (ID is invalid)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/x/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "x"}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "bad input (ID is missing)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person//", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "dependency fail", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "requested registration does not exist", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("person not found")).Once() return mockRegisterModel }, expectedStatus: 
http.StatusNotFound, expectedPayload: ``, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockGetModel := scenario.inModelMock() // build handler handler := NewGetHandler(&testConfig{}, mockGetModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } type testConfig struct { } func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } func (*testConfig) BindAddress() string { return "0.0.0.0:0" } ================================================ FILE: ch11/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/list" ) // ListModel will load all registrations //go:generate mockery -name=ListModel -case underscore -testonly -inpkg -note @generated type ListModel interface { Do() ([]*list.Person, error) } // NewLister is the constructor for ListHandler func NewListHandler(model ListModel) *ListHandler { return &ListHandler{ lister: model, } } // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { lister ListModel } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll people, err := h.lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it 
does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*list.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch11/acme/internal/rest/list_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/list" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockListModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { output := []*list.Person{ { ID: 1, FullName: "John", Phone: "0123456789", }, { ID: 2, FullName: "Paul", Phone: "0123456781", }, { ID: 3, FullName: "George", Phone: "0123456782", }, { ID: 1, FullName: "Ringo", Phone: "0123456783", }, } mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, 
expectedPayload: `{"people":[{"id":1,"name":"John","phone":"0123456789"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":1,"name":"Ringo","phone":"0123456783"}]}` + "\n", }, { desc: "dependency failure", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockListModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "no data", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { // no data var output []*list.Person mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[]}` + "\n", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockListModel := scenario.inModelMock() // build handler handler := NewListHandler(mockListModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch11/acme/internal/rest/mock_get_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/get" "github.com/stretchr/testify/mock" ) // MockGetModel is an autogenerated mock type for the GetModel 
type type MockGetModel struct { mock.Mock } // Do provides a mock function with given fields: ID func (_m *MockGetModel) Do(ID int) (*get.Person, error) { ret := _m.Called(ID) var r0 *get.Person if rf, ok := ret.Get(0).(func(int) *get.Person); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*get.Person) } } var r1 error if rf, ok := ret.Get(1).(func(int) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch11/acme/internal/rest/mock_list_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/list" "github.com/stretchr/testify/mock" ) // MockListModel is an autogenerated mock type for the ListModel type type MockListModel struct { mock.Mock } // Do provides a mock function with given fields: func (_m *MockListModel) Do() ([]*list.Person, error) { ret := _m.Called() var r0 []*list.Person if rf, ok := ret.Get(0).(func() []*list.Person); ok { r0 = rf() } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*list.Person) } } var r1 error if rf, ok := ret.Get(1).(func() error); ok { r1 = rf() } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch11/acme/internal/rest/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/register" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: ctx, in func (_m *MockRegisterModel) Do(ctx context.Context, in *register.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok 
:= ret.Get(0).(func(context.Context, *register.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *register.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch11/acme/internal/rest/not_found.go ================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch11/acme/internal/rest/not_found_test.go ================================================ package rest import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" ) func TestNotFoundHandler_ServeHTTP(t *testing.T) { // build inputs response := httptest.NewRecorder() request := &http.Request{} // call handler notFoundHandler(response, request) // validate outputs require.Equal(t, http.StatusNotFound, response.Code) } ================================================ FILE: ch11/acme/internal/rest/register.go ================================================ package rest import ( "context" "encoding/json" "fmt" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/register" ) // RegisterModel will validate and save a registration //go:generate mockery -name=RegisterModel -case underscore -testonly -inpkg -note @generated type RegisterModel interface { Do(ctx context.Context, in *register.Person) (int, error) } // NewRegisterHandler is the constructor for RegisterHandler func NewRegisterHandler(model RegisterModel) *RegisterHandler { return &RegisterHandler{ registerer: model, } } // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning 
"bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // set latency budget for this API subCtx, cancel := context.WithTimeout(request.Context(), 1500*time.Millisecond) defer cancel() // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // call the business logic using the request data and context id, err := h.register(subCtx, requestPayload) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := ®isterRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(ctx context.Context, requestPayload *registerRequest) (int, error) { person := ®ister.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } return h.registerer.Do(ctx, person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency the wish to register in Currency string `json:"currency"` } ================================================ FILE: ch11/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "encoding/json" 
"errors" "io" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ { desc: "Happy Path", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // valid downstream configuration resultID := 1234 var resultErr error mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(resultID, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusCreated, expectedHeader: "/person/1234/", }, { desc: "Bad Input / User Error", inRequest: func() *http.Request { invalidRequest := bytes.NewBufferString(`this is not valid JSON`) request, err := http.NewRequest("POST", "/person/register", invalidRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // Dependency should not be called mockRegisterModel := &MockRegisterModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, { desc: "Dependency Failure", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // call to the dependency failed resultErr := errors.New("something failed") mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(0, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, } for _, s := range scenarios { scenario := s 
t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockRegisterModel := scenario.inModelMock() // build handler handler := NewRegisterHandler(mockRegisterModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRegisterRequest() io.Reader { requestData := ®isterRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch11/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/logging" "github.com/gorilla/mux" ) // Config is the config for the REST package type Config interface { Logger() logging.Logger BindAddress() string } // New will create and initialize the server func New(cfg Config, getModel GetModel, listModel ListModel, registerModel RegisterModel) *Server { return &Server{ address: cfg.BindAddress(), handlerGet: NewGetHandler(cfg, getModel), handlerList: NewListHandler(listModel), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(registerModel), } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start a HTTP rest for this service func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router, 
Addr: s.address, } // listen for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // convert a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch11/acme/main.go ================================================ package main import ( "context" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/rest" "github.com/google/wire" ) func main() { // bind stop channel to context ctx := context.Background() // start REST server server, err := initializeServer() if err != nil { os.Exit(-1) } server.Listen(ctx.Done()) } // List of wire enabled objects var wireSetWithoutConfig = wire.NewSet( // *exchange.Converter exchange.NewConverter, // *get.Getter get.NewGetter, // *list.Lister list.NewLister, // *register.Registerer wire.Bind(new(register.Exchanger), &exchange.Converter{}), register.NewRegisterer, // *rest.Server wire.Bind(new(rest.GetModel), &get.Getter{}), wire.Bind(new(rest.ListModel), &list.Lister{}), wire.Bind(new(rest.RegisterModel), 
®ister.Registerer{}), rest.New, ) var wireSet = wire.NewSet( wireSetWithoutConfig, // *config.Config config.Load, // *exchange.Converter wire.Bind(new(exchange.Config), &config.Config{}), // *get.Getter wire.Bind(new(get.Config), &config.Config{}), // *list.Lister wire.Bind(new(list.Config), &config.Config{}), // *register.Registerer wire.Bind(new(register.Config), &config.Config{}), // *rest.Server wire.Bind(new(rest.Config), &config.Config{}), ) ================================================ FILE: ch11/acme/main_test.go ================================================ package main import ( "bytes" "context" "errors" "fmt" "net" "net/http" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegister(t *testing.T) { // start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request payload := bytes.NewBufferString(` { "fullName": "Bob", "phone": "0123456789", "currency": "AUD" } `) req, err := http.NewRequest("POST", serverAddress+"/person/register", payload) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusCreated, resp.StatusCode) assert.NotEmpty(t, resp.Header.Get("Location")) } func TestGet(t *testing.T) { // start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request req, err := http.NewRequest("GET", serverAddress+"/person/1/", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusOK, resp.StatusCode) } func TestList(t *testing.T) { // 
start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request req, err := http.NewRequest("GET", serverAddress+"/person/list", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusOK, resp.StatusCode) } func startTestServer(t *testing.T, ctx context.Context) string { // load the standard config (from the ENV) cfg, err := config.Load() require.NoError(t, err) // get a free port (so tests can run concurrently) port, err := getFreePort() require.NoError(t, err) // override config port with free one cfg.Address = net.JoinHostPort("0.0.0.0", port) // start the test server on a random port go func() { // start REST server server := initializeServerCustomConfig(cfg, cfg, cfg, cfg, cfg) server.Listen(ctx.Done()) }() // give the server a chance to start <-time.After(100 * time.Millisecond) // return the address of the test server return "http://" + cfg.Address } func getFreePort() (string, error) { for attempt := 0; attempt <= 10; attempt++ { addr := net.JoinHostPort("", "0") listener, err := net.Listen("tcp", addr) if err != nil { continue } port, err := getPort(listener.Addr()) if err != nil { continue } // close/free the port tcpListener := listener.(*net.TCPListener) cErr := tcpListener.Close() if cErr == nil { file, fErr := tcpListener.File() if fErr == nil { // ignore any errors cleaning up the file _ = file.Close() } return port, nil } } return "", errors.New("no free ports") } func getPort(addr fmt.Stringer) (string, error) { actualAddress := addr.String() _, port, err := net.SplitHostPort(actualAddress) return port, err } ================================================ FILE: ch11/acme/wire.go ================================================ //+build wireinject package main import ( 
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/rest" "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. func initializeServer() (*rest.Server, error) { wire.Build(wireSet) return nil, nil } func initializeServerCustomConfig(_ exchange.Config, _ get.Config, _ list.Config, _ register.Config, _ rest.Config) *rest.Server { wire.Build(wireSetWithoutConfig) return nil } ================================================ FILE: ch11/acme/wire_gen.go ================================================ // Code generated by Wire. DO NOT EDIT. 
//go:generate wire //+build !wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch11/acme/internal/rest" ) // Injectors from wire.go: func initializeServer() (*rest.Server, error) { configConfig, err := config.Load() if err != nil { return nil, err } getter := get.NewGetter(configConfig) lister := list.NewLister(configConfig) converter := exchange.NewConverter(configConfig) registerer := register.NewRegisterer(configConfig, converter) server := rest.New(configConfig, getter, lister, registerer) return server, nil } func initializeServerCustomConfig(exchangeConfig exchange.Config, getConfig get.Config, listConfig list.Config, registerConfig register.Config, restConfig rest.Config) *rest.Server { getter := get.NewGetter(getConfig) lister := list.NewLister(listConfig) converter := exchange.NewConverter(exchangeConfig) registerer := register.NewRegisterer(registerConfig, converter) server := rest.New(restConfig, getter, lister, registerer) return server } ================================================ FILE: ch12/01_improvements/01_test_logging_test.go ================================================ package improvements import ( "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLogging(t *testing.T) { // build log recorder recorder := &LogRecorder{} // Call struct that uses a logger calculator := &Calculator{ logger: recorder, } result := calculator.divide(10, 0) // validate expectations, 
including that the logger was called assert.Equal(t, 0, result) require.Equal(t, 1, len(recorder.Logs)) assert.Equal(t, "cannot divide by 0", recorder.Logs[0]) } type Calculator struct { logger Logger } func (c *Calculator) divide(dividend int, divisor int) int { if divisor == 0 { c.logger.Error("cannot divide by 0") return 0 } return dividend / divisor } // Logger is our standard interface type Logger interface { Error(message string, args ...interface{}) } // LogRecorder implements Logger interface type LogRecorder struct { Logs []string } func (l *LogRecorder) Error(message string, args ...interface{}) { // build log message logMessage := fmt.Sprintf(message, args...) // record log message l.Logs = append(l.Logs, logMessage) } ================================================ FILE: ch12/03_testing/01_mock_get_model.go ================================================ package testing import ( "github.com/stretchr/testify/mock" ) type MockGetModel struct { mock.Mock } func (_m *MockGetModel) Do(ID int) (*Person, error) { outputs := _m.Called(ID) if outputs.Get(0) != nil { return outputs.Get(0).(*Person), outputs.Error(1) } return nil, outputs.Error(1) } type Person struct { ID int FullName string Phone string Currency string Price float64 } ================================================ FILE: ch12/03_testing/02_coverage_ch04.txt ================================================ ---------------------------------------------------------------------------- | Branch | Dir | | | Cov% | Stmts | Cov% | Stmts | Package | ---------------------------------------------------------------------------- | 52.94 | 238 | 0.00 | 3 | acme/ | | 73.33 | 15 | 73.33 | 15 | acme/internal/config/ | | 0.00 | 4 | 0.00 | 4 | acme/internal/logging/ | | 63.33 | 60 | 63.33 | 60 | acme/internal/modules/data/ | | 0.00 | 38 | 0.00 | 38 | acme/internal/modules/exchange/ | | 50.00 | 6 | 50.00 | 6 | acme/internal/modules/get/ | | 25.00 | 12 | 25.00 | 12 | acme/internal/modules/list/ | | 64.29 | 28 | 
64.29 | 28 | acme/internal/modules/register/ | | 73.61 | 72 | 73.61 | 72 | acme/internal/rest/ | ---------------------------------------------------------------------------- ================================================ FILE: ch12/03_testing/03_coverage_ch11.txt ================================================ ---------------------------------------------------------------------------- | Branch | Dir | | | Cov% | Stmts | Cov% | Stmts | Package | ---------------------------------------------------------------------------- | 63.11 | 309 | 30.00 | 20 | acme/ | | 28.57 | 28 | 28.57 | 28 | acme/internal/config/ | | 0.00 | 4 | 0.00 | 4 | acme/internal/logging/ | | 74.65 | 71 | 74.65 | 71 | acme/internal/modules/data/ | | 61.70 | 47 | 61.70 | 47 | acme/internal/modules/exchange/ | | 81.82 | 11 | 81.82 | 11 | acme/internal/modules/get/ | | 38.10 | 21 | 38.10 | 21 | acme/internal/modules/list/ | | 75.76 | 33 | 75.76 | 33 | acme/internal/modules/register/ | | 77.03 | 74 | 77.03 | 74 | acme/internal/rest/ | ---------------------------------------------------------------------------- ================================================ FILE: ch12/03_testing/04_coverage_config.htm ================================================
not tracked not covered covered
package config

import (
        "encoding/json"
        "fmt"
        "io/ioutil"
        "os"

        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging"
)

// DefaultEnvVar is the name of the environment variable that holds the path to the JSON config file.
const DefaultEnvVar = "ACME_CONFIG"

// Config defines the JSON format for the config file.
//
// The exported fields are populated by load via json.Unmarshal. The unexported
// logger field is ignored by the JSON decoder and is created on demand by Logger.
type Config struct {
        // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name)
        DSN string

        // Address is the IP address and port the REST server binds to
        Address string

        // BasePrice is the price of registration
        BasePrice float64

        // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate
        ExchangeRateBaseURL string

        // ExchangeRateAPIKey is the API key for the exchange rate API
        ExchangeRateAPIKey string

        // environmental dependencies (unexported, so never read from the config file)
        logger logging.Logger
}

// Logger returns the logger held by this Config. When none has been set, a
// stdout logger is created on the first call and reused by every later call.
// NOTE(review): the lazy assignment is not synchronised.
func (c *Config) Logger() logging.Logger {
        if c.logger != nil {
                return c.logger
        }

        c.logger = &logging.LoggerStdOut{}
        return c.logger
}

// RegistrationBasePrice returns the base price for registrations (the BasePrice field).
func (c *Config) RegistrationBasePrice() float64 {
        return c.BasePrice
}

// DataDSN returns the data source name used for the database connection (the DSN field).
func (c *Config) DataDSN() string {
        return c.DSN
}

// ExchangeBaseURL returns the Base URL from which we can load exchange rates.
func (c *Config) ExchangeBaseURL() string {
        return c.ExchangeRateBaseURL
}

// ExchangeAPIKey returns the API key for the exchange rate API.
func (c *Config) ExchangeAPIKey() string {
        return c.ExchangeRateAPIKey
}

// BindAddress returns the host and port this service should bind to.
func (c *Config) BindAddress() string {
        return c.Address
}

// Load returns the config read from the file whose path is stored in the
// DefaultEnvVar environment variable.
//
// It returns an error, also written to stderr, when the variable is not set or
// when the file cannot be read or parsed.
func Load() (*Config, error) {
        filename, found := os.LookupEnv(DefaultEnvVar)
        if !found {
                err := fmt.Errorf("failed to locate file specified by %s", DefaultEnvVar)
                // print the message verbatim: passing it as a format string would
                // garble any '%' it contains (and is flagged by go vet)
                fmt.Fprintln(os.Stderr, err)
                return nil, err
        }

        cfg, err := load(filename)
        if err != nil {
                fmt.Fprintf(os.Stderr, "failed to load config with err %s\n", err)
                return nil, err
        }

        return cfg, nil
}

// load reads filename and decodes its JSON content into a new Config.
// Read and parse failures are reported on stderr and returned unchanged.
func load(filename string) (*Config, error) {
        data, err := ioutil.ReadFile(filename)
        if err != nil {
                fmt.Fprintf(os.Stderr, "failed to read config file. err: %s\n", err)
                return nil, err
        }

        out := &Config{}
        err = json.Unmarshal(data, out)
        if err != nil {
                fmt.Fprintf(os.Stderr, "failed to parse config file. err : %s\n", err)
                return nil, err
        }

        return out, nil
}
================================================ FILE: ch12/03_testing/04_coverage_data.htm ================================================
not tracked not covered covered
package data

import (
        "context"
        "database/sql"
        "time"
)

// DAO is a data access object that provides an abstraction over our database interactions.
type DAO struct {
        // cfg is handed to getDB and supplies the Logger used for error reporting
        cfg Config

        // Tracker is an optional query timer; getTracker installs a no-op tracker when it is nil
        Tracker QueryTracker
}

// NewDAO returns a DAO bound to cfg.
//
// Before building the DAO it calls getDB so the database connection pool is
// initialized (if not already done); whatever getDB returns is discarded here.
func NewDAO(cfg Config) *DAO {
        _, _ = getDB(cfg) // best-effort pool initialization; result intentionally ignored

        dao := &DAO{cfg: cfg}
        return dao
}

// Load fetches the person with the given ID.
//
// ErrNotFound is returned when no row matches. Any other error comes from the
// underlying database or from our connection to it. The query runs under a
// one second timeout derived from ctx.
func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) {
        // record how long this call takes
        defer d.getTracker().Track("Load", time.Now())

        db, err := getDB(d.cfg)
        if err != nil {
                d.cfg.Logger().Error("failed to get DB connection. err: %s", err)
                return nil, err
        }

        // bound the database call to one second
        queryCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
        defer cancel()

        // run the select and map the single row onto a Person
        person, err := populatePerson(db.QueryRowContext(queryCtx, sqlLoadByID, ID).Scan)
        switch {
        case err == sql.ErrNoRows:
                d.cfg.Logger().Warn("failed to load requested person '%d'. err: %s", ID, err)
                return nil, ErrNotFound
        case err != nil:
                d.cfg.Logger().Error("failed to convert query result. err: %s", err)
                return nil, err
        }

        return person, nil
}

// LoadAll will attempt to load all people in the database.
// It will return ErrNotFound when there are no people in the database.
// Any other errors returned are caused by the underlying database or our connection to it,
// including errors that interrupt iteration over the result set.
func (d *DAO) LoadAll(ctx context.Context) ([]*Person, error) {
        // track processing time
        defer d.getTracker().Track("LoadAll", time.Now())

        db, err := getDB(d.cfg)
        if err != nil {
                d.cfg.Logger().Error("failed to get DB connection. err: %s", err)
                return nil, err
        }

        // set latency budget for the database call
        subCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
        defer cancel()

        // perform DB select
        rows, err := db.QueryContext(subCtx, sqlLoadAll)
        if err != nil {
                d.cfg.Logger().Error("failed to load people from DB. err: %s", err)
                return nil, err
        }
        defer func() {
                _ = rows.Close()
        }()

        var out []*Person

        for rows.Next() {
                // retrieve columns and populate the person object
                record, err := populatePerson(rows.Scan)
                if err != nil {
                        d.cfg.Logger().Error("failed to convert query result. err: %s", err)
                        return nil, err
                }

                out = append(out, record)
        }

        // rows.Next returns false both at the end of the result set and when reading
        // fails (e.g. the latency budget expiring mid-read); without this check such a
        // failure would surface as a truncated result or as ErrNotFound
        if err := rows.Err(); err != nil {
                d.cfg.Logger().Error("failed to iterate query results. err: %s", err)
                return nil, err
        }

        if len(out) == 0 {
                d.cfg.Logger().Warn("no people found in the database.")
                return nil, ErrNotFound
        }

        return out, nil
}

// Save inserts in as a new person and returns the ID the database assigned to it.
// On any failure defaultPersonID is returned together with the error, which
// originates from the underlying database or our connection to it.
func (d *DAO) Save(ctx context.Context, in *Person) (int, error) {
        // record how long this call takes
        defer d.getTracker().Track("Save", time.Now())

        db, err := getDB(d.cfg)
        if err != nil {
                d.cfg.Logger().Error("failed to get DB connection. err: %s", err)
                return defaultPersonID, err
        }

        // the insert gets at most one second, further limited by ctx
        insertCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
        defer cancel()

        res, err := db.ExecContext(insertCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price)
        if err != nil {
                d.cfg.Logger().Error("failed to save person into DB. err: %s", err)
                return defaultPersonID, err
        }

        // ask the driver which ID the new row received
        newID, err := res.LastInsertId()
        if err != nil {
                d.cfg.Logger().Error("failed to retrieve id of last saved person. err: %s", err)
                return defaultPersonID, err
        }

        return int(newID), nil
}

// getTracker returns the DAO's query tracker. When none was supplied, a no-op
// tracker is installed on first use and reused afterwards.
func (d *DAO) getTracker() QueryTracker {
        if d.Tracker != nil {
                return d.Tracker
        }

        d.Tracker = &noopTracker{}
        return d.Tracker
}
================================================ FILE: ch12/03_testing/04_coverage_exchange.htm ================================================
not tracked not covered covered
package exchange

import (
        "bytes"
        "context"
        "encoding/json"
        "fmt"
        "io/ioutil"
        "math"
        "net/http"
        "time"

        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging"
)

const (
        // request URL for the exchange rate API; the verbs are filled with the base URL,
        // the API key and the currency code, in that order (see loadRateFromServer).
        // NOTE(review): the historical date is hard-coded.
        urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20&currencies=%s"

        // default price that is sent when an error occurs
        defaultPrice = 0.0
)

// Config is the set of settings a Converter reads: a logger plus the base URL
// and API key of the exchange rate service.
type Config interface {
        Logger() logging.Logger
        ExchangeBaseURL() string
        ExchangeAPIKey() string
}

// Converter turns a base price into the price for another currency.
// Note: we are expecting sane inputs and therefore skipping input validation
type Converter struct {
        cfg Config
}

// NewConverter returns a Converter that takes its settings from cfg.
func NewConverter(cfg Config) *Converter {
        conv := &Converter{cfg: cfg}
        return conv
}

// Exchange converts basePrice using the rate the exchange rate API reports for
// currency. The result is basePrice divided by that rate, truncated downwards
// (math.Floor) to two decimal places. On failure it returns defaultPrice and the error.
func (c *Converter) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) {
        response, err := c.loadRateFromServer(ctx, currency)
        if err != nil {
                return defaultPrice, err
        }

        rate, err := c.extractRate(response, currency)
        if err != nil {
                return defaultPrice, err
        }

        converted := basePrice / rate

        // keep two decimal places, rounding down
        return math.Floor(converted*100) / 100, nil
}

// loadRateFromServer requests the rate for currency from the external API.
//
// On success the returned response's Body has already been read in full and
// replaced with an in-memory copy. This is needed because the request context
// is cancelled when this function returns, and cancellation aborts later body
// reads. Callers should still close the Body. On failure no response is returned
// and the upstream body (if any) has been closed.
func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) {
        // build the request URL
        url := fmt.Sprintf(urlFormat,
                c.cfg.ExchangeBaseURL(),
                c.cfg.ExchangeAPIKey(),
                currency)

        // build the request
        req, err := http.NewRequest("GET", url, nil)
        if err != nil {
                c.logger().Warn("[exchange] failed to create request. err: %s", err)
                return nil, err
        }

        // set latency budget for the upstream call (it also covers reading the body below)
        subCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
        defer cancel()

        // replace the default context with our custom one
        req = req.WithContext(subCtx)

        // perform the HTTP request
        response, err := http.DefaultClient.Do(req)
        if err != nil {
                c.logger().Warn("[exchange] failed to load. err: %s", err)
                return nil, err
        }

        if response.StatusCode != http.StatusOK {
                // the response is not handed to the caller, so release it here
                _ = response.Body.Close()
                err = fmt.Errorf("request failed with code %d", response.StatusCode)
                c.logger().Warn("[exchange] %s", err)
                return nil, err
        }

        // read the body while subCtx is still live; the deferred cancel would
        // otherwise abort the caller's later read
        payload, err := ioutil.ReadAll(response.Body)
        _ = response.Body.Close()
        if err != nil {
                c.logger().Warn("[exchange] failed to read response. err: %s", err)
                return nil, err
        }
        response.Body = ioutil.NopCloser(bytes.NewReader(payload))

        return response, nil
}

// extractRate decodes the API response and returns the quote stored under "USD"+currency.
// The response body is always closed. A missing quote is logged at ERROR level and reported as an error.
func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) {
        defer func() {
                _ = response.Body.Close()
        }()

        parsed, err := c.extractResponse(response)
        if err != nil {
                return defaultPrice, err
        }

        // quotes are keyed by "USD" followed by the target currency code
        quoteKey := "USD" + currency
        rate, found := parsed.Quotes[quoteKey]
        if found {
                return rate, nil
        }

        err = fmt.Errorf("response did not include expected currency '%s'", currency)
        c.logger().Error("[exchange] %s", err)
        return defaultPrice, err
}

// extractResponse reads the whole response body and decodes it as JSON into an apiResponseFormat.
// It does not close the body; that is left to the caller.
// Read and decode failures are logged at ERROR level and returned.
func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) {
        payload, err := ioutil.ReadAll(response.Body)
        if err != nil {
                c.logger().Error("[exchange] failed to read response body. err: %s", err)
                return nil, err
        }

        data := &apiResponseFormat{}
        err = json.Unmarshal(payload, data)
        if err != nil {
                c.logger().Error("[exchange] error converting response. err: %s", err)
                return nil, err
        }

        // happy path
        return data, nil
}

// logger is a shorthand for the logger supplied by the converter's config.
func (c *Converter) logger() logging.Logger {
        cfg := c.cfg
        return cfg.Logger()
}

// apiResponseFormat is the subset of the exchange rate API response we decode.
type apiResponseFormat struct {
        // Quotes maps a currency pair key (e.g. "USD"+code) to its rate
        Quotes map[string]float64 `json:"quotes"`
}
================================================ FILE: ch12/03_testing/04_coverage_get.htm ================================================
not tracked not covered covered
package get

import (
        "context"
        "errors"

        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data"
)

// errPersonNotFound is returned when the requested person does not exist in the database.
var errPersonNotFound = errors.New("person not found")

// NewGetter builds a Getter. The data loader is created lazily on first use.
func NewGetter(cfg Config) *Getter {
        getter := &Getter{cfg: cfg}
        return getter
}

// Config is the configuration Getter depends on.
type Config interface {
        // Logger returns the logger to use.
        Logger() logging.Logger
        // DataDSN returns the data source name handed to the data layer.
        DataDSN() string
}

// Getter loads a single person by ID.
// Errors come either from the data layer or, when the person does not exist, errPersonNotFound.
type Getter struct {
        cfg  Config
        data myLoader // created by getLoader unless already set
}

// Do loads the person with the supplied ID and returns a package-local copy of it.
// data.ErrNotFound is translated into errPersonNotFound; any other data layer error is returned unchanged.
func (g *Getter) Do(ID int) (*Person, error) {
        loaded, loadErr := g.getLoader().Load(context.TODO(), ID)
        switch {
        case loadErr == data.ErrNotFound:
                // converting the error hides the data layer's details from our users
                return nil, errPersonNotFound
        case loadErr != nil:
                return nil, loadErr
        }

        return g.convert(loaded), nil
}

// getLoader returns the injected loader, creating (and caching) a data.DAO on first use when none was set.
func (g *Getter) getLoader() myLoader {
        if g.data != nil {
                return g.data
        }

        g.data = data.NewDAO(g.cfg)
        return g.data
}

// convert copies a data layer person into this package's Person type.
func (g *Getter) convert(in *data.Person) *Person {
        out := &Person{
                ID:       in.ID,
                FullName: in.FullName,
                Phone:    in.Phone,
                Currency: in.Currency,
                Price:    in.Price,
        }
        return out
}

// myLoader is the part of the data layer Getter needs; mockery generates a test-only mock of it.
//go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated
type myLoader interface {
        // Load returns the person with the supplied ID
        Load(ctx context.Context, ID int) (*data.Person, error)
}

// Person mirrors (a subset of) data.Person so the data layer type does not leak to our callers.
// Keeping a separate type also lets us drop or hide any internal fields.
type Person struct {
        ID       int
        FullName string
        Phone    string
        Currency string
        Price    float64
}
================================================ FILE: ch12/03_testing/04_coverage_list.htm ================================================
not tracked not covered covered
package list

import (
        "context"
        "errors"

        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data"
)

// errPeopleNotFound is returned when the database contains no people.
var errPeopleNotFound = errors.New("no people found")

// NewLister builds a Lister. The data loader is created lazily on first use.
func NewLister(cfg Config) *Lister {
        lister := &Lister{cfg: cfg}
        return lister
}

// Config is the configuration Lister depends on.
type Config interface {
        // Logger returns the logger to use.
        Logger() logging.Logger
        // DataDSN returns the data source name handed to the data layer.
        DataDSN() string
}

// Lister loads every person in the database.
// Errors returned originate from the data layer (or signal that nobody was found).
type Lister struct {
        cfg  Config
        data myLoader // created by getLoader unless already set
}

// Do loads every person from the data layer and returns them as package-local Person copies.
// It returns errPeopleNotFound when nobody was found: either data.ErrNotFound or an empty result.
// Other data layer errors are returned unchanged.
func (l *Lister) Do() ([]*Person, error) {
        // load all people
        people, err := l.load()
        if err != nil {
                return nil, err
        }

        if len(people) == 0 {
                // an empty but error-free result is reported the same way as "not found"
                return nil, errPeopleNotFound
        }

        return l.convert(people), nil
}

// load fetches all people from the loader, translating data.ErrNotFound into errPeopleNotFound.
func (l *Lister) load() ([]*data.Person, error) {
        people, loadErr := l.getLoader().LoadAll(context.TODO())
        if loadErr == nil {
                return people, nil
        }

        if loadErr == data.ErrNotFound {
                // converting the error encapsulates the data layer's details
                return nil, errPeopleNotFound
        }
        return nil, loadErr
}

// getLoader returns the injected loader. When none was set, it creates a data.DAO with a log-based
// query tracker, caches it and returns it.
func (l *Lister) getLoader() myLoader {
        if l.data == nil {
                dao := data.NewDAO(l.cfg)

                // temporarily add a log tracker
                dao.Tracker = data.NewLogTracker(l.cfg.Logger())

                l.data = dao
        }

        return l.data
}

// convert copies the ID, name and phone of each data layer person into this package's Person type,
// preserving order.
func (l *Lister) convert(in []*data.Person) []*Person {
        out := make([]*Person, 0, len(in))
        for _, record := range in {
                out = append(out, &Person{
                        ID:       record.ID,
                        FullName: record.FullName,
                        Phone:    record.Phone,
                })
        }
        return out
}

// myLoader is the part of the data layer Lister needs; mockery generates a test-only mock of it.
//go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated
type myLoader interface {
        // LoadAll returns every person in the database
        LoadAll(ctx context.Context) ([]*data.Person, error)
}

// Person mirrors a subset of data.Person so the data layer type does not leak to our callers.
// Only the fields the list endpoint needs are kept; internal fields stay hidden.
type Person struct {
        ID       int
        FullName string
        Phone    string
}
================================================ FILE: ch12/03_testing/04_coverage_main.htm ================================================
not tracked not covered covered
package main

import (
        "context"
        "os"

        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/config"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/exchange"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/list"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/register"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/rest"
        "github.com/google/wire"
)

// main builds the REST server via the wire-generated initializeServer and starts listening.
// If initialization fails, the reason is written to stderr and the process exits with status -1.
func main() {
        // context.Background is never cancelled, so no stop signal is sent to Listen from here
        ctx := context.Background()

        // start REST server
        server, err := initializeServer()
        if err != nil {
                // report why startup failed instead of exiting silently
                _, _ = os.Stderr.WriteString("failed to initialize server. err: " + err.Error() + "\n")
                os.Exit(-1)
        }

        server.Listen(ctx.Done())
}

// List of wire enabled objects
//
// wireSetWithoutConfig holds the wire providers and interface bindings for every module and the REST server.
// It does not provide any module's Config interface. A set that includes this one must supply those, as
// wireSet does by binding them to *config.Config.
var wireSetWithoutConfig = wire.NewSet(
        // *exchange.Converter
        exchange.NewConverter,

        // *get.Getter
        get.NewGetter,

        // *list.Lister
        list.NewLister,

        // *register.Registerer
        // NewRegisterer takes a register.Exchanger; satisfy it with the *exchange.Converter provided above
        wire.Bind(new(register.Exchanger), &exchange.Converter{}),
        register.NewRegisterer,

        // *rest.Server
        // bind the rest package's model interfaces to the concrete module objects provided above
        // (presumably consumed by rest.New — its signature is not visible here)
        wire.Bind(new(rest.GetModel), &get.Getter{}),
        wire.Bind(new(rest.ListModel), &list.Lister{}),
        wire.Bind(new(rest.RegisterModel), &register.Registerer{}),
        rest.New,
)

// wireSet is wireSetWithoutConfig plus configuration. config.Load provides the *config.Config, and each
// module's Config interface is bound to that single instance.
var wireSet = wire.NewSet(
        wireSetWithoutConfig,

        // *config.Config
        config.Load,

        // *exchange.Converter
        wire.Bind(new(exchange.Config), &config.Config{}),

        // *get.Getter
        wire.Bind(new(get.Config), &config.Config{}),

        // *list.Lister
        wire.Bind(new(list.Config), &config.Config{}),

        // *register.Registerer
        wire.Bind(new(register.Config), &config.Config{}),

        // *rest.Server
        wire.Bind(new(rest.Config), &config.Config{}),
)
================================================ FILE: ch12/03_testing/04_coverage_register.htm ================================================
not tracked not covered covered
package register

import (
        "context"
        "errors"

        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data"
)

// defaultPersonID is the person ID returned alongside any error.
const defaultPersonID = 0

var (
        // errors returned by input validation
        errNameMissing     = errors.New("name is missing")
        errPhoneMissing    = errors.New("phone is missing")
        errCurrencyMissing = errors.New("currency is missing")
        errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD")

        // supportedCurrencies is a set of the accepted currency codes; membership is a single map lookup
        supportedCurrencies = map[string]struct{}{
                "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {},
                "JPY": {}, "MYR": {}, "SGD": {}, "USD": {},
        }
)

// NewRegisterer builds a Registerer from its config and currency exchanger.
// The data saver is created lazily on first use.
func NewRegisterer(cfg Config, exchanger Exchanger) *Registerer {
        registerer := &Registerer{}
        registerer.cfg = cfg
        registerer.exchanger = exchanger
        return registerer
}

// Exchanger converts a base price into another currency.
//go:generate mockery -name=Exchanger -case underscore -testonly -inpkg -note @generated
type Exchanger interface {
        // Exchange returns basePrice expressed in currency
        Exchange(ctx context.Context, basePrice float64, currency string) (float64, error)
}

// Config is the configuration Registerer depends on.
type Config interface {
        // Logger returns the logger to use.
        Logger() logging.Logger
        // RegistrationBasePrice returns the price that is converted into the requested currency.
        RegistrationBasePrice() float64
        // DataDSN returns the data source name handed to the data layer.
        DataDSN() string
}

// Registerer validates a person, prices the registration in the requested currency and saves it.
// An error is returned when:
// -a required field (name, phone, currency) is empty
// -the currency is not supported
// -the exchange rate cannot be obtained
// -the data layer fails.
type Registerer struct {
        cfg       Config
        exchanger Exchanger
        data      mySaver // created by getSaver unless already set
}

// Do validates in, converts the registration price into in.Currency and saves the registration.
// It returns the new person's ID, or defaultPersonID together with the first error encountered.
func (r *Registerer) Do(ctx context.Context, in *Person) (int, error) {
        if validationErr := r.validateInput(in); validationErr != nil {
                r.logger().Warn("input validation failed with err: %s", validationErr)
                return defaultPersonID, validationErr
        }

        price, priceErr := r.getPrice(ctx, in.Currency)
        if priceErr != nil {
                return defaultPersonID, priceErr
        }

        id, saveErr := r.save(ctx, r.convert(in), price)
        if saveErr != nil {
                // the data layer is expected to log its own failures
                return defaultPersonID, saveErr
        }
        return id, nil
}

// validateInput checks that name, phone and currency are present (in that order) and that the
// currency is supported. It returns the first violation found, or nil.
func (r *Registerer) validateInput(in *Person) error {
        switch {
        case in.FullName == "":
                return errNameMissing
        case in.Phone == "":
                return errPhoneMissing
        case in.Currency == "":
                return errCurrencyMissing
        }

        _, supported := supportedCurrencies[in.Currency]
        if !supported {
                return errInvalidCurrency
        }
        return nil
}

// getPrice converts the configured registration base price into currency using the exchanger.
// On failure it logs a warning and returns a zero price together with the error.
func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) {
        price, err := r.exchanger.Exchange(ctx, r.cfg.RegistrationBasePrice(), currency)
        if err != nil {
                r.logger().Warn("failed to convert the price. err: %s", err)
                // a price, not a person ID: return the zero value rather than defaultPersonID
                return 0, err
        }

        return price, nil
}

// save persists a new record built from in's name, phone and currency plus the supplied price.
// in.ID and in.Price are not used. It returns the saver's ID and error unchanged.
func (r *Registerer) save(ctx context.Context, in *data.Person, price float64) (int, error) {
        record := &data.Person{
                Price:    price,
                Currency: in.Currency,
                Phone:    in.Phone,
                FullName: in.FullName,
        }
        saver := r.getSaver()
        return saver.Save(ctx, record)
}

// getSaver returns the injected saver, creating (and caching) a data.DAO on first use when none was set.
func (r *Registerer) getSaver() mySaver {
        if r.data != nil {
                return r.data
        }

        r.data = data.NewDAO(r.cfg)
        return r.data
}

// logger is a shorthand for the logger supplied by the registerer's config.
func (r *Registerer) logger() logging.Logger {
        cfg := r.cfg
        return cfg.Logger()
}

// convert copies this package's Person into the data layer's Person type.
func (r *Registerer) convert(in *Person) *data.Person {
        out := &data.Person{
                ID:       in.ID,
                FullName: in.FullName,
                Phone:    in.Phone,
                Currency: in.Currency,
                Price:    in.Price,
        }
        return out
}

// mySaver is the part of the data layer Registerer needs; mockery generates a test-only mock of it.
//go:generate mockery -name=mySaver -case underscore -testonly -inpkg -note @generated
type mySaver interface {
        // Save stores the person and returns its new ID
        Save(ctx context.Context, in *data.Person) (int, error)
}

// Person mirrors (a subset of) data.Person so the data layer type does not leak to our callers.
// Keeping a separate type also lets us drop or hide any internal fields.
type Person struct {
        ID       int
        FullName string
        Phone    string
        Currency string
        Price    float64
}
================================================ FILE: ch12/03_testing/04_coverage_rest.htm ================================================
not tracked not covered covered
package rest

import (
        "encoding/json"
        "errors"
        "fmt"
        "io"
        "net/http"
        "strconv"

        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging"
        "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get"
        "github.com/gorilla/mux"
)

const (
        // defaultPersonID is the person ID returned alongside any error
        defaultPersonID = 0
        // muxVarID is the name of the route variable that carries the person ID
        muxVarID = "id"
)

// GetModel loads a single registration (person) by ID.
//go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated
type GetModel interface {
        // Do returns the person with the supplied ID
        Do(ID int) (*get.Person, error)
}

// GetConfig is the configuration GetHandler depends on.
type GetConfig interface {
        // Logger returns the logger used to report malformed requests.
        Logger() logging.Logger
}

// NewGetHandler builds a GetHandler that serves people loaded by model.
func NewGetHandler(cfg GetConfig, model GetModel) *GetHandler {
        handler := &GetHandler{}
        handler.cfg = cfg
        handler.getter = model
        return handler
}

// GetHandler serves the "Get Person" endpoint.
// Status codes, in this simplified example:
// - 400 Bad Request when the person ID cannot be extracted;
// - 404 Not Found whenever the model returns an error (all such errors are treated as user errors);
// - 500 Internal Server Error if writing the JSON response fails.
// Programmer errors are possible but are expected to be caught in testing.
type GetHandler struct {
        cfg    GetConfig
        getter GetModel
}

// ServeHTTP implements http.Handler: it extracts the ID, loads the person and writes it as JSON.
func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
        id, idErr := h.extractID(request)
        if idErr != nil {
                response.WriteHeader(http.StatusBadRequest)
                return
        }

        person, getErr := h.getter.Do(id)
        if getErr != nil {
                // other layers are expected to log this failure
                response.WriteHeader(http.StatusNotFound)
                return
        }

        if writeErr := h.writeJSON(response, person); writeErr != nil {
                // unexpected and unrecoverable; all we can do is signal a server error
                response.WriteHeader(http.StatusInternalServerError)
        }
}

// extractID reads the person ID from the mux route variable muxVarID and converts it to an int.
// A missing variable is logged at WARN level and a non-numeric value at ERROR level; in both cases
// defaultPersonID is returned together with the error.
func (h *GetHandler) extractID(request *http.Request) (int, error) {
        // ID is part of the URL, so we extract it from there
        vars := mux.Vars(request)
        idAsString, exists := vars[muxVarID]
        if !exists {
                // log and return error
                err := errors.New("[get] person id missing from request")
                h.cfg.Logger().Warn("%s", err)
                return defaultPersonID, err
        }

        // convert ID to int
        id, err := strconv.Atoi(idAsString)
        if err != nil {
                // log and return error
                err = fmt.Errorf("[get] failed to convert person id into a number. err: %s", err)
                // the message embeds the caller-supplied id, so pass it as an argument rather than as the
                // format string; any '%' it contains is then printed literally instead of being interpreted
                h.cfg.Logger().Error("%s", err)
                return defaultPersonID, err
        }

        return id, nil
}

// writeJSON encodes person as a getResponseFormat JSON document onto writer.
// When writer is an http.ResponseWriter, this first write implicitly sends HTTP 200 OK.
func (h *GetHandler) writeJSON(writer io.Writer, person *get.Person) error {
        encoder := json.NewEncoder(writer)
        return encoder.Encode(&getResponseFormat{
                ID:       person.ID,
                FullName: person.FullName,
                Phone:    person.Phone,
                Currency: person.Currency,
                Price:    person.Price,
        })
}

// getResponseFormat is the JSON body returned by the "Get Person" endpoint.
type getResponseFormat struct {
        ID       int     `json:"id"`
        FullName string  `json:"name"`
        Phone    string  `json:"phone"`
        Currency string  `json:"currency"`
        Price    float64 `json:"price"`
}
================================================ FILE: ch12/04_new_service/01_data_with_cache/dao.go ================================================ package data import ( "context" "database/sql" "encoding/json" "fmt" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/04_new_service/01_data_with_cache/internal/cache" ) // DAO is a data access object that provides an abstraction over our database interactions. type DAO struct { cfg Config db *sql.DB cache *cache.Cache } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) { // load from cache out := d.loadFromCache(ID) if out != nil { return out, nil } // load from database row := d.db.QueryRowContext(ctx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { d.cfg.Logger().Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } d.cfg.Logger().Error("failed to convert query result. 
err: %s", err) return nil, err } // save person into the cache d.saveToCache(ID, out) return out, nil } func (d *DAO) loadFromCache(ID int) *Person { payload, err := d.cache.Get(d.buildCacheKey(ID)) if err != nil { d.cfg.Logger().Error("failed to load requested person from cache with error: %s", err) return nil } if payload == nil { return nil } out := &Person{} err = json.Unmarshal(payload, out) if err != nil { d.cfg.Logger().Error("failed to decode cache response with error: %s", err) } return out } func (d *DAO) saveToCache(ID int, person *Person) { payload, err := json.Marshal(person) if err != nil { d.cfg.Logger().Error("failed to encode person to JSON with error: %s", err) return } err = d.cache.Set(d.buildCacheKey(ID), payload) if err != nil { d.cfg.Logger().Error("failed to save person into cache with error: %s", err) } } func (d *DAO) buildCacheKey(ID int) string { return fmt.Sprintf("person-%d", ID) } ================================================ FILE: ch12/04_new_service/01_data_with_cache/data.go ================================================ package data import ( "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/04_new_service/01_data_with_cache/internal/logging" _ "github.com/go-sql-driver/mysql" ) const ( // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) // Config is the configuration for the data package type Config interface { // Logger returns a reference to the logger Logger() logging.Logger // DataDSN returns the data source name DataDSN() string } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch12/04_new_service/01_data_with_cache/internal/cache/cache.go ================================================ package cache import ( "errors" ) type Cache struct{} func (c *Cache) Get(key string) ([]byte, error) { return nil, errors.New("not implemented") } func (c *Cache) Set(key string, data []byte) error { return errors.New("not implemented") } ================================================ FILE: ch12/04_new_service/01_data_with_cache/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // Logger is our standard interface type Logger interface { Debug(message string, args ...interface{}) Info(message string, args ...interface{}) Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at 
DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) } // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) } // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) } ================================================ FILE: ch12/acme/internal/config/config.go ================================================ package config import ( "encoding/json" "fmt" "io/ioutil" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" ) // DefaultEnvVar is the default environment variable the points to the config file const DefaultEnvVar = "ACME_CONFIG" // Config defines the JSON format for the config file type Config struct { // DSN is the data source name (format: https://github.com/go-sql-driver/mysql/#dsn-data-source-name) DSN string // Address is the IP address and port to bind this rest to Address string // BasePrice is the price of registration BasePrice float64 // ExchangeRateBaseURL is the server and protocol part of the URL from which to load the exchange rate ExchangeRateBaseURL string // ExchangeRateAPIKey is the API for the exchange rate API ExchangeRateAPIKey string // environmental dependencies logger logging.Logger } // Logger returns a reference to the singleton logger func (c *Config) Logger() logging.Logger { if c.logger == nil { c.logger = &logging.LoggerStdOut{} } return c.logger } // RegistrationBasePrice returns the base price for registrations func (c *Config) RegistrationBasePrice() float64 { return c.BasePrice } // DataDSN returns the DSN func (c *Config) DataDSN() string { return c.DSN } // ExchangeBaseURL returns the Base URL from 
which we can load exchange rates func (c *Config) ExchangeBaseURL() string { return c.ExchangeRateBaseURL } // ExchangeAPIKey returns the DSN func (c *Config) ExchangeAPIKey() string { return c.ExchangeRateAPIKey } // BindAddress returns the host and port this service should bind to func (c *Config) BindAddress() string { return c.Address } // Load returns the config loaded from environment func Load() (*Config, error) { filename, found := os.LookupEnv(DefaultEnvVar) if !found { err := fmt.Errorf("failed to locate file specified by %s", DefaultEnvVar) fmt.Fprintf(os.Stderr, err.Error()) return nil, err } cfg, err := load(filename) if err != nil { fmt.Fprintf(os.Stderr, "failed to load config with err %s", err) return nil, err } return cfg, nil } func load(filename string) (*Config, error) { out := &Config{} bytes, err := ioutil.ReadFile(filename) if err != nil { fmt.Fprintf(os.Stderr, "failed to read config file. err: %s", err) return nil, err } err = json.Unmarshal(bytes, out) if err != nil { fmt.Fprintf(os.Stderr, "failed to parse config file. 
err : %s", err) return nil, err } return out, nil } ================================================ FILE: ch12/acme/internal/config/config_test.go ================================================ package config import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLoad(t *testing.T) { scenarios := []struct { desc string in string expectedConfig *Config expectError bool }{ { desc: "happy path", in: "../../../../default-config.json", expectedConfig: &Config{ DSN: "[insert your db config here]", Address: "0.0.0.0:8080", BasePrice: 100.00, ExchangeRateBaseURL: "http://apilayer.net", ExchangeRateAPIKey: "[insert your API key here]", }, expectError: false, }, { desc: "invalid path", in: "invalid.json", expectedConfig: nil, expectError: true, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { result, resultErr := load(scenario.in) require.Equal(t, scenario.expectError, resultErr != nil, "err: %s", resultErr) assert.Equal(t, scenario.expectedConfig, result, scenario.desc) }) } } ================================================ FILE: ch12/acme/internal/logging/logging.go ================================================ package logging import ( "fmt" ) // Logger is our standard interface type Logger interface { Debug(message string, args ...interface{}) Info(message string, args ...interface{}) Warn(message string, args ...interface{}) Error(message string, args ...interface{}) } // LoggerStdOut logs to std out type LoggerStdOut struct{} // Debug logs messages at DEBUG level func (l LoggerStdOut) Debug(message string, args ...interface{}) { fmt.Printf("[DEBUG] "+message, args...) } // Info logs messages at INFO level func (l LoggerStdOut) Info(message string, args ...interface{}) { fmt.Printf("[INFO] "+message, args...) } // Warn logs messages at WARN level func (l LoggerStdOut) Warn(message string, args ...interface{}) { fmt.Printf("[WARN] "+message, args...) 
} // Error logs messages at ERROR level func (l LoggerStdOut) Error(message string, args ...interface{}) { fmt.Printf("[ERROR] "+message, args...) } ================================================ FILE: ch12/acme/internal/modules/data/dao.go ================================================ package data import ( "context" "database/sql" "time" ) // NewDAO will initialize the database connection pool (if not already done) and return a data access object which // can be used to interact with the database func NewDAO(cfg Config) *DAO { // initialize the db connection pool _, _ = getDB(cfg) return &DAO{ cfg: cfg, } } // DAO is a data access object that provides an abstraction over our database interactions. type DAO struct { cfg Config // Tracker is an optional query timer Tracker QueryTracker } // Load will attempt to load and return a person. // It will return ErrNotFound when the requested person does not exist. // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) Load(ctx context.Context, ID int) (*Person, error) { // track processing time defer d.getTracker().Track("Load", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select row := db.QueryRowContext(subCtx, sqlLoadByID, ID) // retrieve columns and populate the person object out, err := populatePerson(row.Scan) if err != nil { if err == sql.ErrNoRows { d.cfg.Logger().Warn("failed to load requested person '%d'. err: %s", ID, err) return nil, ErrNotFound } d.cfg.Logger().Error("failed to convert query result. 
err: %s", err) return nil, err } return out, nil } // LoadAll will attempt to load all people in the database // It will return ErrNotFound when there are not people in the database // Any other errors returned are caused by the underlying database or our connection to it. func (d *DAO) LoadAll(ctx context.Context) ([]*Person, error) { // track processing time defer d.getTracker().Track("LoadAll", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return nil, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB select rows, err := db.QueryContext(subCtx, sqlLoadAll) if err != nil { return nil, err } defer func() { _ = rows.Close() }() var out []*Person for rows.Next() { // retrieve columns and populate the person object record, err := populatePerson(rows.Scan) if err != nil { d.cfg.Logger().Error("failed to convert query result. err: %s", err) return nil, err } out = append(out, record) } if len(out) == 0 { d.cfg.Logger().Warn("no people found in the database.") return nil, ErrNotFound } return out, nil } // Save will save the supplied person and return the ID of the newly created person or an error. // Errors returned are caused by the underlying database or our connection to it. func (d *DAO) Save(ctx context.Context, in *Person) (int, error) { // track processing time defer d.getTracker().Track("Save", time.Now()) db, err := getDB(d.cfg) if err != nil { d.cfg.Logger().Error("failed to get DB connection. err: %s", err) return defaultPersonID, err } // set latency budget for the database call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // perform DB insert result, err := db.ExecContext(subCtx, sqlInsert, in.FullName, in.Phone, in.Currency, in.Price) if err != nil { d.cfg.Logger().Error("failed to save person into DB. 
err: %s", err) return defaultPersonID, err } // retrieve and return the ID of the person created id, err := result.LastInsertId() if err != nil { d.cfg.Logger().Error("failed to retrieve id of last saved person. err: %s", err) return defaultPersonID, err } return int(id), nil } func (d *DAO) getTracker() QueryTracker { if d.Tracker == nil { d.Tracker = &noopTracker{} } return d.Tracker } ================================================ FILE: ch12/acme/internal/modules/data/data.go ================================================ package data import ( "database/sql" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" _ "github.com/go-sql-driver/mysql" ) const ( // default person id (returned on error) defaultPersonID = 0 // SQL statements as constants (to reduce duplication and maintenance in tests) sqlAllColumns = "id, fullname, phone, currency, price" sqlInsert = "INSERT INTO person (fullname, phone, currency, price) VALUES (?, ?, ?, ?)" sqlLoadAll = "SELECT " + sqlAllColumns + " FROM person" sqlLoadByID = "SELECT " + sqlAllColumns + " FROM person WHERE id = ? 
LIMIT 1" ) var ( db *sql.DB // ErrNotFound is returned when the no records where matched by the query ErrNotFound = errors.New("not found") ) // Config is the configuration for the data package type Config interface { // Logger returns a reference to the logger Logger() logging.Logger // DataDSN returns the data source name DataDSN() string } var getDB = func(cfg Config) (*sql.DB, error) { if db == nil { var err error db, err = sql.Open("mysql", cfg.DataDSN()) if err != nil { // if the DB cannot be accessed we are dead panic(err.Error()) } } return db, nil } // Person is the data transfer object (DTO) for this package type Person struct { // ID is the unique ID for this person ID int // FullName is the name of this person FullName string // Phone is the phone for this person Phone string // Currency is the currency this person has paid in Currency string // Price is the amount (in the above currency) paid by this person Price float64 } // custom type so we can convert sql results to easily type scanner func(dest ...interface{}) error // reduce the duplication (and maintenance) between sql.Row and sql.Rows usage func populatePerson(scanner scanner) (*Person, error) { out := &Person{} err := scanner(&out.ID, &out.FullName, &out.Phone, &out.Currency, &out.Price) return out, err } ================================================ FILE: ch12/acme/internal/modules/data/data_test.go ================================================ package data import ( "context" "database/sql" "errors" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSave_happyPath(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() 
require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnResult(sqlmock.NewResult(2, 1)) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.NoError(t, err) assert.Equal(t, 2, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_insertError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // define a mock db testDb, dbMock, err := sqlmock.New() defer testDb.Close() require.NoError(t, err) // configure the mock db queryRegex := convertSQLToRegex(sqlInsert) dbMock.ExpectExec(queryRegex).WillReturnError(errors.New("failed to insert")) // monkey patching starts here db = testDb // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, in) // validate result require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) assert.NoError(t, dbMock.ExpectationsWereMet()) } func TestSave_getDBError(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // monkey patching starts here defer func(original func(_ Config) (*sql.DB, error)) { // restore original DB (after test) getDB = original }(getDB) // replace getDB() function for this test getDB = func(_ Config) (*sql.DB, error) { return nil, errors.New("getDB() failed") } // end of monkey patch // inputs in := &Person{ FullName: "Jake Blues", Phone: "01234567890", Currency: "AUD", Price: 123.45, } // call function dao := NewDAO(&testConfig{}) resultID, err := dao.Save(ctx, 
in) require.Error(t, err) assert.Equal(t, defaultPersonID, resultID) } func TestLoadAll_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResults []*Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(1, "John", "0123456789", "AUD", 12.34)) }, expectedResults: []*Person{ { ID: 1, FullName: "John", Phone: "0123456789", Currency: "AUD", Price: 12.34, }, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResults: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey patch the db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) results, err := dao.LoadAll(ctx) // validate results assert.Equal(t, scenario.expectedResults, results, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } func TestLoad_tableDrivenTest(t *testing.T) { // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() scenarios := []struct { desc string configureMockDB func(sqlmock.Sqlmock) expectedResult *Person expectError bool }{ { desc: "happy path", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := 
convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnRows( sqlmock.NewRows(strings.Split(sqlAllColumns, ", ")). AddRow(2, "Paul", "0123456789", "CAD", 23.45)) }, expectedResult: &Person{ ID: 2, FullName: "Paul", Phone: "0123456789", Currency: "CAD", Price: 23.45, }, expectError: false, }, { desc: "load error", configureMockDB: func(dbMock sqlmock.Sqlmock) { queryRegex := convertSQLToRegex(sqlLoadAll) dbMock.ExpectQuery(queryRegex).WillReturnError(errors.New("something failed")) }, expectedResult: nil, expectError: true, }, } for _, scenario := range scenarios { // define a mock db testDb, dbMock, err := sqlmock.New() require.NoError(t, err) // configure the mock db scenario.configureMockDB(dbMock) // monkey db for this test original := *db db = testDb // call function dao := NewDAO(&testConfig{}) result, err := dao.Load(ctx, 2) // validate results assert.Equal(t, scenario.expectedResult, result, scenario.desc) assert.Equal(t, scenario.expectError, err != nil, scenario.desc) assert.NoError(t, dbMock.ExpectationsWereMet()) // restore original DB (after test) db = &original testDb.Close() } } // convert SQL string to regex by treating the entire query as a literal func convertSQLToRegex(in string) string { return `\Q` + in + `\E` } type testConfig struct{} // Logger implements Config func (t *testConfig) Logger() logging.Logger { return logging.LoggerStdOut{} } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } ================================================ FILE: ch12/acme/internal/modules/data/tracker.go ================================================ package data import ( "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" ) // QueryTracker is an interface to track query timing type QueryTracker interface { // Track will record/out the time a query took by calculating time.Now().Sub(start) Track(key string, start time.Time) } // NO-OP implementation of QueryTracker type 
noopTracker struct{} // Track implements QueryTracker func (_ *noopTracker) Track(_ string, _ time.Time) { // intentionally does nothing } // NewLogTracker returns a Tracker that outputs tracking data to log func NewLogTracker(logger logging.Logger) *LogTracker { return &LogTracker{ logger: logger, } } // LogTracker implements QueryTracker and outputs to the supplied logger type LogTracker struct { logger logging.Logger } // Track implements QueryTracker func (l *LogTracker) Track(key string, start time.Time) { l.logger.Info("[%s] Timing: %s\n", key, time.Now().Sub(start).String()) } ================================================ FILE: ch12/acme/internal/modules/exchange/converter.go ================================================ package exchange import ( "context" "encoding/json" "fmt" "io/ioutil" "math" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" ) const ( // request URL for the exchange rate API urlFormat = "%s/api/historical?access_key=%s&date=2018-06-20¤cies=%s" // default price that is sent when an error occurs defaultPrice = 0.0 ) // NewConverter creates and initializes the converter func NewConverter(cfg Config) *Converter { return &Converter{ cfg: cfg, } } // Config is the config for Converter type Config interface { Logger() logging.Logger ExchangeBaseURL() string ExchangeAPIKey() string } // Converter will convert the base price to the currency supplied // Note: we are expecting sane inputs and therefore skipping input validation type Converter struct { cfg Config } // Exchange will perform the conversion func (c *Converter) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { // load rate from the external API response, err := c.loadRateFromServer(ctx, currency) if err != nil { return defaultPrice, err } // extract rate from response rate, err := c.extractRate(response, currency) if err != nil { return defaultPrice, err } // apply rate and round to 2 
decimal places return math.Floor((basePrice/rate)*100) / 100, nil } // load rate from the external API func (c *Converter) loadRateFromServer(ctx context.Context, currency string) (*http.Response, error) { // build the request url := fmt.Sprintf(urlFormat, c.cfg.ExchangeBaseURL(), c.cfg.ExchangeAPIKey(), currency) // perform request req, err := http.NewRequest("GET", url, nil) if err != nil { c.logger().Warn("[exchange] failed to create request. err: %s", err) return nil, err } // set latency budget for the upstream call subCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() // replace the default context with our custom one req = req.WithContext(subCtx) // perform the HTTP request response, err := http.DefaultClient.Do(req) if err != nil { c.logger().Warn("[exchange] failed to load. err: %s", err) return nil, err } if response.StatusCode != http.StatusOK { err = fmt.Errorf("request failed with code %d", response.StatusCode) c.logger().Warn("[exchange] %s", err) return nil, err } return response, nil } func (c *Converter) extractRate(response *http.Response, currency string) (float64, error) { defer func() { _ = response.Body.Close() }() // extract data from response data, err := c.extractResponse(response) if err != nil { return defaultPrice, err } // pull rate from response data rate, found := data.Quotes["USD"+currency] if !found { err = fmt.Errorf("response did not include expected currency '%s'", currency) c.logger().Error("[exchange] %s", err) return defaultPrice, err } // happy path return rate, nil } func (c *Converter) extractResponse(response *http.Response) (*apiResponseFormat, error) { payload, err := ioutil.ReadAll(response.Body) if err != nil { c.logger().Error("[exchange] failed to ready response body. err: %s", err) return nil, err } data := &apiResponseFormat{} err = json.Unmarshal(payload, data) if err != nil { c.logger().Error("[exchange] error converting response. 
err: %s", err) return nil, err } // happy path return data, nil } func (c *Converter) logger() logging.Logger { return c.cfg.Logger() } // the response format from the exchange rate API type apiResponseFormat struct { Quotes map[string]float64 `json:"quotes"` } ================================================ FILE: ch12/acme/internal/modules/exchange/converter_ext_bounday_test.go ================================================ // +build external package exchange import ( "context" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestExternalBoundaryTest(t *testing.T) { // define the config cfg, err := config.Load() require.NoError(t, err) // create a converter to test converter := NewConverter(cfg) // fetch from the server response, err := converter.loadRateFromServer(context.Background(), "AUD") require.NotNil(t, response) require.NoError(t, err) // parse the response resultRate, err := converter.extractRate(response, "AUD") require.NoError(t, err) // validate the result assert.True(t, resultRate > 0) } ================================================ FILE: ch12/acme/internal/modules/exchange/converter_int_bounday_test.go ================================================ package exchange import ( "context" "net/http" "net/http/httptest" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/stretchr/testify/assert" ) func TestInternalBoundaryTest(t *testing.T) { // start our test server server := httptest.NewServer(&happyExchangeRateService{}) defer server.Close() // define the config cfg := &testConfig{ baseURL: server.URL, apiKey: "", } // create a converter to test converter := NewConverter(cfg) resultRate, resultErr := converter.Exchange(context.Background(), 100.00, "AUD") // validate the result assert.Equal(t, 101.01, resultRate) assert.NoError(t, resultErr) } 
type happyExchangeRateService struct{} // ServeHTTP implements http.Handler func (*happyExchangeRateService) ServeHTTP(response http.ResponseWriter, request *http.Request) { payload := []byte(` { "success":true, "historical":true, "date":"2010-11-09", "timestamp":1289347199, "source":"USD", "quotes":{ "USDAUD":0.989981 } }`) response.Write(payload) } func TestExchange_invalidResponseFromServer(t *testing.T) { // start our test server server := httptest.NewServer(http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { payload := []byte(`invalid payload`) response.Write(payload) })) defer server.Close() // inputs ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() basePrice := 12.34 currency := "AUD" cfg := &testConfig{ baseURL: server.URL, apiKey: "", } // create a converter to test converter := NewConverter(cfg) result, resultErr := converter.Exchange(ctx, basePrice, currency) // validate response assert.Equal(t, float64(0), result) assert.Error(t, resultErr) } // test implementation of Config type testConfig struct { baseURL string apiKey string } // Logger implements Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // ExchangeBaseURL implements Config func (t *testConfig) ExchangeBaseURL() string { return t.baseURL } // ExchangeAPIKey implements Config func (t *testConfig) ExchangeAPIKey() string { return t.apiKey } ================================================ FILE: ch12/acme/internal/modules/get/get.go ================================================ package get import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" ) var ( // error thrown when the requested person is not in the database errPersonNotFound = errors.New("person not found") ) // NewGetter creates and initializes a Getter func NewGetter(cfg 
Config) *Getter { return &Getter{ cfg: cfg, } } // Config is the configuration for Getter type Config interface { Logger() logging.Logger DataDSN() string } // Getter will attempt to load a person. // It can return an error caused by the data layer or when the requested person is not found type Getter struct { cfg Config data myLoader } // Do will perform the get func (g *Getter) Do(ID int) (*Person, error) { // load person from the data layer person, err := g.getLoader().Load(context.TODO(), ID) if err != nil { if err == data.ErrNotFound { // By converting the error we are hiding the implementation details from our users. return nil, errPersonNotFound } return nil, err } return g.convert(person), err } func (g *Getter) getLoader() myLoader { if g.data == nil { g.data = data.NewDAO(g.cfg) } return g.data } func (g *Getter) convert(in *data.Person) *Person { return &Person{ ID: in.ID, Currency: in.Currency, FullName: in.FullName, Phone: in.Phone, Price: in.Price, } } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { Load(ctx context.Context, ID int) (*data.Person, error) } // Person is a copy/sub-set of data.Person so that the relationship does not leak. 
// It also allows us to remove/hide and internal fields type Person struct { ID int FullName string Phone string Currency string Price float64 } ================================================ FILE: ch12/acme/internal/modules/get/go_test.go ================================================ package get import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetter_Do_happyPath(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockResult := &data.Person{ ID: 1234, FullName: "Doug", } mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(mockResult, nil).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.NoError(t, err) assert.Equal(t, ID, person.ID) assert.Equal(t, "Doug", person.FullName) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_noSuchPerson(t *testing.T) { // inputs ID := 5678 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, data.ErrNotFound).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.Equal(t, errPersonNotFound, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } func TestGetter_Do_error(t *testing.T) { // inputs ID := 1234 // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("Load", mock.Anything, ID).Return(nil, errors.New("something failed")).Once() // call method getter := &Getter{ data: mockLoader, } person, err := getter.Do(ID) // validate expectations require.Error(t, err) assert.Nil(t, person) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: 
ch12/acme/internal/modules/get/mock_my_loader_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package get import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMyLoader is an autogenerated mock type for the myLoader type type mockMyLoader struct { mock.Mock } // Load provides a mock function with given fields: ctx, ID func (_m *mockMyLoader) Load(ctx context.Context, ID int) (*data.Person, error) { ret := _m.Called(ctx, ID) var r0 *data.Person if rf, ok := ret.Get(0).(func(context.Context, int) *data.Person); ok { r0 = rf(ctx, ID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { r1 = rf(ctx, ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch12/acme/internal/modules/list/list.go ================================================ package list import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" ) var ( // error thrown when there are no people in the database errPeopleNotFound = errors.New("no people found") ) // NewLister creates and initializes a Lister func NewLister(cfg Config) *Lister { return &Lister{ cfg: cfg, } } // Config is the config for Lister type Config interface { Logger() logging.Logger DataDSN() string } // Lister will attempt to load all people in the database. 
// It can return an error caused by the data layer type Lister struct { cfg Config data myLoader } // Exchange will load the people from the data layer func (l *Lister) Do() ([]*Person, error) { // load all people people, err := l.load() if err != nil { return nil, err } if len(people) == 0 { // special processing for 0 people returned return nil, errPeopleNotFound } return l.convert(people), nil } // load all people func (l *Lister) load() ([]*data.Person, error) { people, err := l.getLoader().LoadAll(context.TODO()) if err != nil { if err == data.ErrNotFound { // By converting the error we are encapsulating the implementation details from our users. return nil, errPeopleNotFound } return nil, err } return people, nil } func (l *Lister) getLoader() myLoader { if l.data == nil { l.data = data.NewDAO(l.cfg) // temporarily add a log tracker l.data.(*data.DAO).Tracker = data.NewLogTracker(l.cfg.Logger()) } return l.data } func (l *Lister) convert(in []*data.Person) []*Person { out := make([]*Person, len(in)) for index, thisRecord := range in { out[index] = &Person{ ID: thisRecord.ID, FullName: thisRecord.FullName, Phone: thisRecord.Phone, } } return out } //go:generate mockery -name=myLoader -case underscore -testonly -inpkg -note @generated type myLoader interface { LoadAll(ctx context.Context) ([]*data.Person, error) } // Person is a copy/sub-set of data.Person so that the relationship does not leak. 
// It also allows us to remove/hide and internal fields type Person struct { ID int FullName string Phone string } ================================================ FILE: ch12/acme/internal/modules/list/list_test.go ================================================ package list import ( "errors" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestLister_Do_happyPath(t *testing.T) { // configure the mock loader mockResult := []*data.Person{ { ID: 1234, FullName: "Sally", }, { ID: 5678, FullName: "Jane", }, } mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(mockResult, nil).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.NoError(t, err) assert.Equal(t, 2, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func TestLister_Do_noResults(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, data.ErrNotFound).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Equal(t, errPeopleNotFound, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } func TestLister_Do_error(t *testing.T) { // configure the mock loader mockLoader := &mockMyLoader{} mockLoader.On("LoadAll", mock.Anything).Return(nil, errors.New("something failed")).Once() // call method lister := &Lister{ data: mockLoader, } persons, err := lister.load() // validate expectations require.Error(t, err) assert.Equal(t, 0, len(persons)) assert.True(t, mockLoader.AssertExpectations(t)) } ================================================ FILE: ch12/acme/internal/modules/list/mock_my_loader_test.go ================================================ // Code 
generated by mockery v1.0.0 // @generated package list import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMyLoader is an autogenerated mock type for the myLoader type type mockMyLoader struct { mock.Mock } // LoadAll provides a mock function with given fields: ctx func (_m *mockMyLoader) LoadAll(ctx context.Context) ([]*data.Person, error) { ret := _m.Called(ctx) var r0 []*data.Person if rf, ok := ret.Get(0).(func(context.Context) []*data.Person); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*data.Person) } } var r1 error if rf, ok := ret.Get(1).(func(context.Context) error); ok { r1 = rf(ctx) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch12/acme/internal/modules/register/mock_exchanger_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package register import ( "context" "github.com/stretchr/testify/mock" ) // MockExchanger is an autogenerated mock type for the Exchanger type type MockExchanger struct { mock.Mock } // Exchange provides a mock function with given fields: ctx, basePrice, currency func (_m *MockExchanger) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { ret := _m.Called(ctx, basePrice, currency) var r0 float64 if rf, ok := ret.Get(0).(func(context.Context, float64, string) float64); ok { r0 = rf(ctx, basePrice, currency) } else { r0 = ret.Get(0).(float64) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, float64, string) error); ok { r1 = rf(ctx, basePrice, currency) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch12/acme/internal/modules/register/mock_my_saver_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package register import ( "context" 
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" "github.com/stretchr/testify/mock" ) // mockMySaver is an autogenerated mock type for the mySaver type type mockMySaver struct { mock.Mock } // Save provides a mock function with given fields: ctx, in func (_m *mockMySaver) Save(ctx context.Context, in *data.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok := ret.Get(0).(func(context.Context, *data.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *data.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch12/acme/internal/modules/register/register.go ================================================ package register import ( "context" "errors" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/data" ) const ( // default person id (returned on error) defaultPersonID = 0 ) var ( // validation errors errNameMissing = errors.New("name is missing") errPhoneMissing = errors.New("phone is missing") errCurrencyMissing = errors.New("currency is missing") errInvalidCurrency = errors.New("currency is invalid, supported types are AUD, CNY, EUR, GBP, JPY, MYR, SGD, USD") // a little trick to make checking for supported currencies easier supportedCurrencies = map[string]struct{}{ "AUD": {}, "CNY": {}, "EUR": {}, "GBP": {}, "JPY": {}, "MYR": {}, "SGD": {}, "USD": {}, } ) // NewRegisterer creates and initializes a Registerer func NewRegisterer(cfg Config, exchanger Exchanger) *Registerer { return &Registerer{ cfg: cfg, exchanger: exchanger, } } // Exchanger will convert from one currency to another //go:generate mockery -name=Exchanger -case underscore -testonly -inpkg -note @generated type Exchanger interface { // Exchange 
will perform the conversion Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) } // Config is the configuration for the Registerer type Config interface { Logger() logging.Logger RegistrationBasePrice() float64 DataDSN() string } // Registerer validates the supplied person, calculates the price in the requested currency and saves the result. // It will return an error when: // -the person object does not include all the fields // -the currency is invalid // -the exchange rate cannot be loaded // -the data layer throws an error. type Registerer struct { cfg Config exchanger Exchanger data mySaver } // Do is API for this struct func (r *Registerer) Do(ctx context.Context, in *Person) (int, error) { // validate the request err := r.validateInput(in) if err != nil { r.logger().Warn("input validation failed with err: %s", err) return defaultPersonID, err } // get price in the requested currency price, err := r.getPrice(ctx, in.Currency) if err != nil { return defaultPersonID, err } // save registration id, err := r.save(ctx, r.convert(in), price) if err != nil { // no need to log here as we expect the data layer to do so return defaultPersonID, err } return id, nil } // validate input and return error on fail func (r *Registerer) validateInput(in *Person) error { if in.FullName == "" { return errNameMissing } if in.Phone == "" { return errPhoneMissing } if in.Currency == "" { return errCurrencyMissing } if _, found := supportedCurrencies[in.Currency]; !found { return errInvalidCurrency } // happy path return nil } // get price in the requested currency func (r *Registerer) getPrice(ctx context.Context, currency string) (float64, error) { price, err := r.exchanger.Exchange(ctx, r.cfg.RegistrationBasePrice(), currency) if err != nil { r.logger().Warn("failed to convert the price. 
err: %s", err) return defaultPersonID, err } return price, nil } // save the registration func (r *Registerer) save(ctx context.Context, in *data.Person, price float64) (int, error) { person := &data.Person{ FullName: in.FullName, Phone: in.Phone, Currency: in.Currency, Price: price, } return r.getSaver().Save(ctx, person) } func (r *Registerer) getSaver() mySaver { if r.data == nil { r.data = data.NewDAO(r.cfg) } return r.data } func (r *Registerer) logger() logging.Logger { return r.cfg.Logger() } func (r *Registerer) convert(in *Person) *data.Person { return &data.Person{ ID: in.ID, Currency: in.Currency, FullName: in.FullName, Phone: in.Phone, Price: in.Price, } } //go:generate mockery -name=mySaver -case underscore -testonly -inpkg -note @generated type mySaver interface { Save(ctx context.Context, in *data.Person) (int, error) } // Person is a copy/sub-set of data.Person so that the relationship does not leak. // It also allows us to remove/hide and internal fields type Person struct { ID int FullName string Phone string Currency string Price float64 } ================================================ FILE: ch12/acme/internal/modules/register/register_test.go ================================================ package register import ( "context" "errors" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterer_Do_happyPath(t *testing.T) { // configure the mock saver mockResult := 888 mockSaver := &mockMySaver{} mockSaver.On("Save", mock.Anything, mock.Anything).Return(mockResult, nil).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: 
&stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.NoError(t, err) assert.Equal(t, 888, ID) assert.True(t, mockSaver.AssertExpectations(t)) } func TestRegisterer_Do_error(t *testing.T) { // configure the mock saver mockSaver := &mockMySaver{} mockSaver.On("Save", mock.Anything, mock.Anything).Return(0, errors.New("something failed")).Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: &stubExchanger{}, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) assert.True(t, mockSaver.AssertExpectations(t)) } func TestRegisterer_Do_exchangeError(t *testing.T) { // configure the mocks mockSaver := &mockMySaver{} mockExchanger := &MockExchanger{} mockExchanger. On("Exchange", mock.Anything, mock.Anything, mock.Anything). Return(0.0, errors.New("failed to load conversion")). 
Once() // define context and therefore test timeout ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() // inputs in := &Person{ FullName: "Chang", Phone: "11122233355", Currency: "CNY", } // call method registerer := &Registerer{ cfg: &testConfig{}, exchanger: mockExchanger, data: mockSaver, } ID, err := registerer.Do(ctx, in) // validate expectations require.Error(t, err) assert.Equal(t, 0, ID) assert.True(t, mockSaver.AssertExpectations(t)) assert.True(t, mockExchanger.AssertExpectations(t)) } // Stub implementation of Config type testConfig struct{} // Logger implement Config func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } // RegistrationBasePrice implement Config func (t *testConfig) RegistrationBasePrice() float64 { return 12.34 } // DataDSN implements Config func (t *testConfig) DataDSN() string { return "" } type stubExchanger struct{} // Exchange implements Exchanger func (s stubExchanger) Exchange(ctx context.Context, basePrice float64, currency string) (float64, error) { return 12.34, nil } ================================================ FILE: ch12/acme/internal/rest/get.go ================================================ package rest import ( "encoding/json" "errors" "fmt" "io" "net/http" "strconv" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get" "github.com/gorilla/mux" ) const ( // default person id (returned on error) defaultPersonID = 0 // key in the mux where the ID is stored muxVarID = "id" ) // GetModel will load a registration //go:generate mockery -name=GetModel -case underscore -testonly -inpkg -note @generated type GetModel interface { Do(ID int) (*get.Person, error) } // GetConfig is the config for the Get Handler type GetConfig interface { Logger() logging.Logger } // NewGetHandler is the constructor for GetHandler func NewGetHandler(cfg 
GetConfig, model GetModel) *GetHandler { return &GetHandler{ cfg: cfg, getter: model, } } // GetHandler is the HTTP handler for the "Get Person" endpoint // In this simplified example we are assuming all possible errors are user errors and returning "bad request" HTTP 400 // or "not found" HTTP 404 // There are some programmer errors possible but hopefully these will be caught in testing. type GetHandler struct { cfg GetConfig getter GetModel } // ServeHTTP implements http.Handler func (h *GetHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // extract person id from request id, err := h.extractID(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // attempt get person, err := h.getter.Do(id) if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, person) if err != nil { // this error should not happen but if it does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // extract the person ID from the request func (h *GetHandler) extractID(request *http.Request) (int, error) { // ID is part of the URL, so we extract it from there vars := mux.Vars(request) idAsString, exists := vars[muxVarID] if !exists { // log and return error err := errors.New("[get] person id missing from request") h.cfg.Logger().Warn(err.Error()) return defaultPersonID, err } // convert ID to int id, err := strconv.Atoi(idAsString) if err != nil { // log and return error err = fmt.Errorf("[get] failed to convert person id into a number. 
err: %s", err) h.cfg.Logger().Error(err.Error()) return defaultPersonID, err } return id, nil } // output the supplied person as JSON func (h *GetHandler) writeJSON(writer io.Writer, person *get.Person) error { output := &getResponseFormat{ ID: person.ID, FullName: person.FullName, Phone: person.Phone, Currency: person.Currency, Price: person.Price, } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } // the JSON response format type getResponseFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` Currency string `json:"currency"` Price float64 `json:"price"` } ================================================ FILE: ch12/acme/internal/rest/get_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestGetHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockGetModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { output := &get.Person{ ID: 1, FullName: "John", Phone: "0123456789", Currency: "USD", Price: 100, } mockGetModel := &MockGetModel{} mockGetModel.On("Do", mock.Anything).Return(output, nil).Once() return mockGetModel }, expectedStatus: http.StatusOK, expectedPayload: 
`{"id":1,"name":"John","phone":"0123456789","currency":"USD","price":100}` + "\n", }, { desc: "bad input (ID is invalid)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/x/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "x"}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "bad input (ID is missing)", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person//", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{}) }, inModelMock: func() *MockGetModel { // expect the model not to be called mockRegisterModel := &MockGetModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedPayload: ``, }, { desc: "dependency fail", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockRegisterModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "requested registration does not exist", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/1/", nil) require.NoError(t, err) // set values into request (required by the mux) return mux.SetURLVars(req, map[string]string{muxVarID: "1"}) }, inModelMock: func() *MockGetModel { mockRegisterModel := &MockGetModel{} mockRegisterModel.On("Do", mock.Anything).Return(nil, errors.New("person not found")).Once() return mockRegisterModel }, expectedStatus: 
http.StatusNotFound, expectedPayload: ``, }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockGetModel := scenario.inModelMock() // build handler handler := NewGetHandler(&testConfig{}, mockGetModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } type testConfig struct { } func (t *testConfig) Logger() logging.Logger { return &logging.LoggerStdOut{} } func (*testConfig) BindAddress() string { return "0.0.0.0:0" } ================================================ FILE: ch12/acme/internal/rest/list.go ================================================ package rest import ( "encoding/json" "io" "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/list" ) // ListModel will load all registrations //go:generate mockery -name=ListModel -case underscore -testonly -inpkg -note @generated type ListModel interface { Do() ([]*list.Person, error) } // NewLister is the constructor for ListHandler func NewListHandler(model ListModel) *ListHandler { return &ListHandler{ lister: model, } } // ListHandler is the HTTP handler for the "List Do people" endpoint // In this simplified example we are assuming all possible errors are system errors (HTTP 500) type ListHandler struct { lister ListModel } // ServeHTTP implements http.Handler func (h *ListHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // attempt loadAll people, err := h.lister.Do() if err != nil { // not need to log here as we can expect other layers to do so response.WriteHeader(http.StatusNotFound) return } // happy path err = h.writeJSON(response, people) if err != nil { // this error should not happen but if it 
does there is nothing we can do to recover response.WriteHeader(http.StatusInternalServerError) } } // output the result as JSON func (h *ListHandler) writeJSON(writer io.Writer, people []*list.Person) error { output := &listResponseFormat{ People: make([]*listResponseItemFormat, len(people)), } for index, record := range people { output.People[index] = &listResponseItemFormat{ ID: record.ID, FullName: record.FullName, Phone: record.Phone, } } // call to http.ResponseWriter.Write() will cause HTTP OK (200) to be output as well return json.NewEncoder(writer).Encode(output) } type listResponseFormat struct { People []*listResponseItemFormat `json:"people"` } type listResponseItemFormat struct { ID int `json:"id"` FullName string `json:"name"` Phone string `json:"phone"` } ================================================ FILE: ch12/acme/internal/rest/list_test.go ================================================ package rest import ( "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/list" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestListHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockListModel expectedStatus int expectedPayload string }{ { desc: "happy path", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { output := []*list.Person{ { ID: 1, FullName: "John", Phone: "0123456789", }, { ID: 2, FullName: "Paul", Phone: "0123456781", }, { ID: 3, FullName: "George", Phone: "0123456782", }, { ID: 1, FullName: "Ringo", Phone: "0123456783", }, } mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, 
expectedPayload: `{"people":[{"id":1,"name":"John","phone":"0123456789"},{"id":2,"name":"Paul","phone":"0123456781"},{"id":3,"name":"George","phone":"0123456782"},{"id":1,"name":"Ringo","phone":"0123456783"}]}` + "\n", }, { desc: "dependency failure", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(nil, errors.New("something failed")).Once() return mockListModel }, expectedStatus: http.StatusNotFound, expectedPayload: ``, }, { desc: "no data", inRequest: func() *http.Request { req, err := http.NewRequest("GET", "/person/list", nil) require.NoError(t, err) return req }, inModelMock: func() *MockListModel { // no data var output []*list.Person mockListModel := &MockListModel{} mockListModel.On("Do", mock.Anything).Return(output, nil).Once() return mockListModel }, expectedStatus: http.StatusOK, expectedPayload: `{"people":[]}` + "\n", }, } for _, s := range scenarios { scenario := s t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockListModel := scenario.inModelMock() // build handler handler := NewListHandler(mockListModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code, scenario.desc) payload, _ := ioutil.ReadAll(response.Body) assert.Equal(t, scenario.expectedPayload, string(payload), scenario.desc) }) } } ================================================ FILE: ch12/acme/internal/rest/mock_get_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get" "github.com/stretchr/testify/mock" ) // MockGetModel is an autogenerated mock type for the GetModel 
type type MockGetModel struct { mock.Mock } // Do provides a mock function with given fields: ID func (_m *MockGetModel) Do(ID int) (*get.Person, error) { ret := _m.Called(ID) var r0 *get.Person if rf, ok := ret.Get(0).(func(int) *get.Person); ok { r0 = rf(ID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*get.Person) } } var r1 error if rf, ok := ret.Get(1).(func(int) error); ok { r1 = rf(ID) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch12/acme/internal/rest/mock_list_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/list" "github.com/stretchr/testify/mock" ) // MockListModel is an autogenerated mock type for the ListModel type type MockListModel struct { mock.Mock } // Do provides a mock function with given fields: func (_m *MockListModel) Do() ([]*list.Person, error) { ret := _m.Called() var r0 []*list.Person if rf, ok := ret.Get(0).(func() []*list.Person); ok { r0 = rf() } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*list.Person) } } var r1 error if rf, ok := ret.Get(1).(func() error); ok { r1 = rf() } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch12/acme/internal/rest/mock_register_model_test.go ================================================ // Code generated by mockery v1.0.0 // @generated package rest import ( "context" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/register" "github.com/stretchr/testify/mock" ) // MockRegisterModel is an autogenerated mock type for the RegisterModel type type MockRegisterModel struct { mock.Mock } // Do provides a mock function with given fields: ctx, in func (_m *MockRegisterModel) Do(ctx context.Context, in *register.Person) (int, error) { ret := _m.Called(ctx, in) var r0 int if rf, ok 
:= ret.Get(0).(func(context.Context, *register.Person) int); ok { r0 = rf(ctx, in) } else { r0 = ret.Get(0).(int) } var r1 error if rf, ok := ret.Get(1).(func(context.Context, *register.Person) error); ok { r1 = rf(ctx, in) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: ch12/acme/internal/rest/not_found.go ================================================ package rest import ( "net/http" ) func notFoundHandler(response http.ResponseWriter, _ *http.Request) { response.WriteHeader(http.StatusNotFound) _, _ = response.Write([]byte(`Not found`)) } ================================================ FILE: ch12/acme/internal/rest/not_found_test.go ================================================ package rest import ( "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" ) func TestNotFoundHandler_ServeHTTP(t *testing.T) { // build inputs response := httptest.NewRecorder() request := &http.Request{} // call handler notFoundHandler(response, request) // validate outputs require.Equal(t, http.StatusNotFound, response.Code) } ================================================ FILE: ch12/acme/internal/rest/register.go ================================================ package rest import ( "context" "encoding/json" "fmt" "net/http" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/register" ) // RegisterModel will validate and save a registration //go:generate mockery -name=RegisterModel -case underscore -testonly -inpkg -note @generated type RegisterModel interface { Do(ctx context.Context, in *register.Person) (int, error) } // NewRegisterHandler is the constructor for RegisterHandler func NewRegisterHandler(model RegisterModel) *RegisterHandler { return &RegisterHandler{ registerer: model, } } // RegisterHandler is the HTTP handler for the "Register" endpoint // In this simplified example we are assuming all possible errors are user errors and returning 
"bad request" HTTP 400. // There are some programmer errors possible but hopefully these will be caught in testing. type RegisterHandler struct { registerer RegisterModel } // ServeHTTP implements http.Handler func (h *RegisterHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { // set latency budget for this API subCtx, cancel := context.WithTimeout(request.Context(), 1500*time.Millisecond) defer cancel() // extract payload from request requestPayload, err := h.extractPayload(request) if err != nil { // output error response.WriteHeader(http.StatusBadRequest) return } // call the business logic using the request data and context id, err := h.register(subCtx, requestPayload) if err != nil { // no need to log here as we can expect other layers to do so response.WriteHeader(http.StatusBadRequest) return } // happy path response.Header().Add("Location", fmt.Sprintf("/person/%d/", id)) response.WriteHeader(http.StatusCreated) } // extract payload from request func (h *RegisterHandler) extractPayload(request *http.Request) (*registerRequest, error) { requestPayload := &registerRequest{} decoder := json.NewDecoder(request.Body) err := decoder.Decode(requestPayload) if err != nil { return nil, err } return requestPayload, nil } // call the logic layer func (h *RegisterHandler) register(ctx context.Context, requestPayload *registerRequest) (int, error) { person := &register.Person{ FullName: requestPayload.FullName, Phone: requestPayload.Phone, Currency: requestPayload.Currency, } return h.registerer.Do(ctx, person) } // register endpoint request format type registerRequest struct { // FullName of the person FullName string `json:"fullName"` // Phone of the person Phone string `json:"phone"` // Currency they wish to register in Currency string `json:"currency"` } ================================================ FILE: ch12/acme/internal/rest/register_test.go ================================================ package rest import ( "bytes" "encoding/json"
"errors" "io" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestRegisterHandler_ServeHTTP(t *testing.T) { scenarios := []struct { desc string inRequest func() *http.Request inModelMock func() *MockRegisterModel expectedStatus int expectedHeader string }{ { desc: "Happy Path", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // valid downstream configuration resultID := 1234 var resultErr error mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(resultID, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusCreated, expectedHeader: "/person/1234/", }, { desc: "Bad Input / User Error", inRequest: func() *http.Request { invalidRequest := bytes.NewBufferString(`this is not valid JSON`) request, err := http.NewRequest("POST", "/person/register", invalidRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // Dependency should not be called mockRegisterModel := &MockRegisterModel{} return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, { desc: "Dependency Failure", inRequest: func() *http.Request { validRequest := buildValidRegisterRequest() request, err := http.NewRequest("POST", "/person/register", validRequest) require.NoError(t, err) return request }, inModelMock: func() *MockRegisterModel { // call to the dependency failed resultErr := errors.New("something failed") mockRegisterModel := &MockRegisterModel{} mockRegisterModel.On("Do", mock.Anything, mock.Anything).Return(0, resultErr).Once() return mockRegisterModel }, expectedStatus: http.StatusBadRequest, expectedHeader: "", }, } for _, s := range scenarios { scenario := s 
t.Run(scenario.desc, func(t *testing.T) { // define model layer mock mockRegisterModel := scenario.inModelMock() // build handler handler := NewRegisterHandler(mockRegisterModel) // perform request response := httptest.NewRecorder() handler.ServeHTTP(response, scenario.inRequest()) // validate outputs require.Equal(t, scenario.expectedStatus, response.Code) // call should output the location to the new person resultHeader := response.Header().Get("Location") assert.Equal(t, scenario.expectedHeader, resultHeader) // validate the mock was used as we expected assert.True(t, mockRegisterModel.AssertExpectations(t)) }) } } func buildValidRegisterRequest() io.Reader { requestData := &registerRequest{ FullName: "Joan Smith", Currency: "AUD", Phone: "01234567890", } data, _ := json.Marshal(requestData) return bytes.NewBuffer(data) } ================================================ FILE: ch12/acme/internal/rest/server.go ================================================ package rest import ( "net/http" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/logging" "github.com/gorilla/mux" ) // Config is the config for the REST package type Config interface { Logger() logging.Logger BindAddress() string } // New will create and initialize the server func New(cfg Config, getModel GetModel, listModel ListModel, registerModel RegisterModel) *Server { return &Server{ address: cfg.BindAddress(), handlerGet: NewGetHandler(cfg, getModel), handlerList: NewListHandler(listModel), handlerNotFound: notFoundHandler, handlerRegister: NewRegisterHandler(registerModel), } } // Server is the HTTP REST server type Server struct { address string server *http.Server handlerGet http.Handler handlerList http.Handler handlerNotFound http.HandlerFunc handlerRegister http.Handler } // Listen will start an HTTP REST server for this service; it blocks until stop is closed (or the server fails to start) func (s *Server) Listen(stop <-chan struct{}) { router := s.buildRouter() // create the HTTP server s.server = &http.Server{ Handler: router,
Addr: s.address, } // listen for shutdown go func() { // wait for shutdown signal <-stop _ = s.server.Close() }() // start the HTTP server; the returned error (http.ErrServerClosed after Close, or e.g. a bind failure) is discarded as Listen has no way to report it _ = s.server.ListenAndServe() } // configure the endpoints to handlers func (s *Server) buildRouter() http.Handler { router := mux.NewRouter() // map URL endpoints to HTTP handlers router.Handle("/person/{id}/", s.handlerGet).Methods("GET") router.Handle("/person/list", s.handlerList).Methods("GET") router.Handle("/person/register", s.handlerRegister).Methods("POST") // configure a "catch all" not found handler router.NotFoundHandler = s.handlerNotFound return router } ================================================ FILE: ch12/acme/main.go ================================================ package main import ( "context" "fmt" "os" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/rest" "github.com/google/wire" ) func main() { // context.Background() is never cancelled, so ctx.Done() returns nil and the server runs until the process is terminated ctx := context.Background() // start REST server server, err := initializeServer() if err != nil { // report why startup failed instead of exiting silently _, _ = fmt.Fprintf(os.Stderr, "failed to initialize server: %s\n", err) os.Exit(-1) } server.Listen(ctx.Done()) } // List of wire enabled objects var wireSetWithoutConfig = wire.NewSet( // *exchange.Converter exchange.NewConverter, // *get.Getter get.NewGetter, // *list.Lister list.NewLister, // *register.Registerer wire.Bind(new(register.Exchanger), &exchange.Converter{}), register.NewRegisterer, // *rest.Server wire.Bind(new(rest.GetModel), &get.Getter{}), wire.Bind(new(rest.ListModel), &list.Lister{}), wire.Bind(new(rest.RegisterModel),
®ister.Registerer{}), rest.New, ) var wireSet = wire.NewSet( wireSetWithoutConfig, // *config.Config config.Load, // *exchange.Converter wire.Bind(new(exchange.Config), &config.Config{}), // *get.Getter wire.Bind(new(get.Config), &config.Config{}), // *list.Lister wire.Bind(new(list.Config), &config.Config{}), // *register.Registerer wire.Bind(new(register.Config), &config.Config{}), // *rest.Server wire.Bind(new(rest.Config), &config.Config{}), ) ================================================ FILE: ch12/acme/main_test.go ================================================ package main import ( "bytes" "context" "errors" "fmt" "net" "net/http" "testing" "time" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRegister(t *testing.T) { // start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request payload := bytes.NewBufferString(` { "fullName": "Bob", "phone": "0123456789", "currency": "AUD" } `) req, err := http.NewRequest("POST", serverAddress+"/person/register", payload) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusCreated, resp.StatusCode) assert.NotEmpty(t, resp.Header.Get("Location")) } func TestGet(t *testing.T) { // start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request req, err := http.NewRequest("GET", serverAddress+"/person/1/", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // validate expectations assert.Equal(t, http.StatusOK, resp.StatusCode) } func TestList(t *testing.T) { // 
start a context with a max execution time ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // start test server serverAddress := startTestServer(t, ctx) // build and send request req, err := http.NewRequest("GET", serverAddress+"/person/list", nil) require.NoError(t, err) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) // release the connection once the test is done with the response defer func() { _ = resp.Body.Close() }() // validate expectations assert.Equal(t, http.StatusOK, resp.StatusCode) } func startTestServer(t *testing.T, ctx context.Context) string { // load the standard config (from the ENV) cfg, err := config.Load() require.NoError(t, err) // get a free port (so tests can run concurrently) port, err := getFreePort() require.NoError(t, err) // override config port with free one cfg.Address = net.JoinHostPort("0.0.0.0", port) // start the test server on a random port go func() { // start REST server server := initializeServerCustomConfig(cfg, cfg, cfg, cfg, cfg) server.Listen(ctx.Done()) }() // give the server a chance to start <-time.After(100 * time.Millisecond) // return the address of the test server return "http://" + cfg.Address } // getFreePort asks the OS for an unused TCP port (by listening on port 0), releases it and returns it; it gives up after 11 attempts func getFreePort() (string, error) { for attempt := 0; attempt <= 10; attempt++ { addr := net.JoinHostPort("", "0") listener, err := net.Listen("tcp", addr) if err != nil { continue } port, err := getPort(listener.Addr()) // always close/free the port (even when reading it failed) so the listener is not leaked and the test server can bind to it cErr := listener.Close() if err != nil || cErr != nil { continue } return port, nil } return "", errors.New("no free ports") } func getPort(addr fmt.Stringer) (string, error) { actualAddress := addr.String() _, port, err := net.SplitHostPort(actualAddress) return port, err } ================================================ FILE: ch12/acme/wire.go ================================================ //+build wireinject package main import (
"github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/rest" "github.com/google/wire" ) // The build tag makes sure the stub is not built in the final build. func initializeServer() (*rest.Server, error) { wire.Build(wireSet) return nil, nil } func initializeServerCustomConfig(_ exchange.Config, _ get.Config, _ list.Config, _ register.Config, _ rest.Config) *rest.Server { wire.Build(wireSetWithoutConfig) return nil } ================================================ FILE: ch12/acme/wire_gen.go ================================================ // Code generated by Wire. DO NOT EDIT. 
//go:generate wire //+build !wireinject package main import ( "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/config" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/exchange" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/get" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/list" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/modules/register" "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go/ch12/acme/internal/rest" ) // Injectors from wire.go: func initializeServer() (*rest.Server, error) { configConfig, err := config.Load() if err != nil { return nil, err } getter := get.NewGetter(configConfig) lister := list.NewLister(configConfig) converter := exchange.NewConverter(configConfig) registerer := register.NewRegisterer(configConfig, converter) server := rest.New(configConfig, getter, lister, registerer) return server, nil } func initializeServerCustomConfig(exchangeConfig exchange.Config, getConfig get.Config, listConfig list.Config, registerConfig register.Config, restConfig rest.Config) *rest.Server { getter := get.NewGetter(getConfig) lister := list.NewLister(listConfig) converter := exchange.NewConverter(exchangeConfig) registerer := register.NewRegisterer(registerConfig, converter) server := rest.New(restConfig, getter, lister, registerer) return server } ================================================ FILE: ch12/fake.go ================================================ package ch12 func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } ================================================ FILE: default-config.json ================================================ { "dsn": "[insert your db config here]", "address": "0.0.0.0:8080", "basePrice": 100.00, "exchangeRateBaseURL": 
"http://apilayer.net", "exchangeRateAPIKey": "[insert your API key here]" } ================================================ FILE: fake.go ================================================ package Hands_On_Dependency_Injection_in_Go func init() { // This file is included so that Go tools (like `go list`) will find Go code in this directory and not error } ================================================ FILE: resources/create.sql ================================================ CREATE DATABASE IF NOT EXISTS acme; CREATE TABLE IF NOT EXISTS `acme`.`person` ( `id` BIGINT(20) NOT NULL AUTO_INCREMENT, `fullName` VARCHAR(100) NOT NULL, `phone` CHAR(15) NOT NULL, `currency` CHAR(3) NOT NULL, `price` DECIMAL(6,2) NOT NULL, PRIMARY KEY (`id`)); INSERT INTO `acme`.`person` (`id`, `fullName`, `phone`, `currency`, `price`) VALUES ("1", "John", "0123456780", "USD", 100); INSERT INTO `acme`.`person` (`id`, `fullName`, `phone`, `currency`, `price`) VALUES ("2", "Paul", "0123456781", "AUD", 120); INSERT INTO `acme`.`person` (`id`, `fullName`, `phone`, `currency`, `price`) VALUES ("3", "George", "0123456782", "GBP", 150); INSERT INTO `acme`.`person` (`id`, `fullName`, `phone`, `currency`, `price`) VALUES ("4", "Ringo", "0123456783", "EUR", 110); ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/LICENSE ================================================ The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) Copyright (c) 2013-2018, DATA-DOG team All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * The name DataDog.lt may not be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/README.md ================================================ [![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.svg)](https://travis-ci.org/DATA-DOG/go-sqlmock) [![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.svg)](https://godoc.org/github.com/DATA-DOG/go-sqlmock) [![Go Report Card](https://goreportcard.com/badge/github.com/DATA-DOG/go-sqlmock)](https://goreportcard.com/report/github.com/DATA-DOG/go-sqlmock) [![codecov.io](https://codecov.io/github/DATA-DOG/go-sqlmock/branch/master/graph/badge.svg)](https://codecov.io/github/DATA-DOG/go-sqlmock) # Sql driver mock for Golang **sqlmock** is a mock library implementing [sql/driver](https://godoc.org/database/sql/driver). Which has one and only purpose - to simulate any **sql** driver behavior in tests, without needing a real database connection. 
It helps maintain a correct **TDD** workflow. - this library is now complete and stable (you may not find new changes for this reason). - supports concurrency and multiple connections. - supports **go1.8** Context related feature mocking and Named sql parameters. - does not require any modifications to your source code. - the driver allows mocking any sql driver method behavior. - has strict by default expectation order matching. - has no third party dependencies. **NOTE:** in **v1.2.0** **sqlmock.Rows** has changed to struct from interface; if you were using any type references to that interface, you will need to switch it to a pointer struct type. Also, **sqlmock.Rows** was used to implement the **driver.Rows** interface, which was not required or useful for mocking and was removed. Hope it will not cause issues. ## Install go get gopkg.in/DATA-DOG/go-sqlmock.v1 ## Documentation and Examples Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference. See **.travis.yml** for supported **go** versions. A different use case is to functionally test with a real database using [go-txdb](https://github.com/DATA-DOG/go-txdb): all database related actions are isolated within a single transaction so the database can remain in the same state.
See implementation examples: - [blog API server](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/blog) - [the same orders example](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/orders) ### Something you may want to test ``` go package main import "database/sql" func recordStats(db *sql.DB, userID, productID int64) (err error) { tx, err := db.Begin() if err != nil { return } defer func() { switch err { case nil: err = tx.Commit() default: tx.Rollback() } }() if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil { return } if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil { return } return } func main() { // @NOTE: the real connection is not required for tests db, err := sql.Open("mysql", "root@/blog") if err != nil { panic(err) } defer db.Close() if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil { panic(err) } } ``` ### Tests with sqlmock ``` go package main import ( "fmt" "testing" "gopkg.in/DATA-DOG/go-sqlmock.v1" ) // a successful case func TestShouldUpdateStats(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer db.Close() mock.ExpectBegin() mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() // now we execute our method if err = recordStats(db, 2, 3); err != nil { t.Errorf("error was not expected while updating stats: %s", err) } // we make sure that all expectations were met if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } // a failing test case func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a 
stub database connection", err) } defer db.Close() mock.ExpectBegin() mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectExec("INSERT INTO product_viewers"). WithArgs(2, 3). WillReturnError(fmt.Errorf("some error")) mock.ExpectRollback() // now we execute our method if err = recordStats(db, 2, 3); err == nil { t.Errorf("was expecting an error, but there was none") } // we make sure that all expectations were met if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } ``` ## Matching arguments like time.Time There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case **sqlmock** provides an [Argument](https://godoc.org/github.com/DATA-DOG/go-sqlmock#Argument) interface which can be used in more sophisticated matching. Here is a simple example of time argument matching: ``` go type AnyTime struct{} // Match satisfies sqlmock.Argument interface func (a AnyTime) Match(v driver.Value) bool { _, ok := v.(time.Time) return ok } func TestAnyTimeArgument(t *testing.T) { t.Parallel() db, mock, err := New() if err != nil { t.Errorf("an error '%s' was not expected when opening a stub database connection", err) } defer db.Close() mock.ExpectExec("INSERT INTO users"). WithArgs("john", AnyTime{}). WillReturnResult(NewResult(1, 1)) _, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now()) if err != nil { t.Errorf("error '%s' was not expected, while inserting a row", err) } if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } } ``` It only asserts that argument is of `time.Time` type. ## Run tests go test -race ## Change Log - **2017-09-01** - it is now possible to expect that prepared statement will be closed, using **ExpectedPrepare.WillBeClosed**. - **2017-02-09** - implemented support for **go1.8** features. 
**Rows** interface was changed to struct but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now accept multiple row sets. - **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL query. It should still be validated even if Exec or Query is not executed on that prepared statement. - **2016-02-23** - added **sqlmock.AnyArg()** function to provide any kind of argument matcher. - **2016-02-23** - convert expected arguments to driver.Value as natural driver does, the change may affect time.Time comparison and will be stricter. See [issue](https://github.com/DATA-DOG/go-sqlmock/issues/31). - **2015-08-27** - **v1** api change, concurrency support, all known issues fixed. - **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error - **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5) - **2014-05-29** allow to match arguments in more sophisticated ways, by providing an **sqlmock.Argument** interface - **2014-04-21** introduce **sqlmock.New()** to open a mock database connection for tests. This method calls sql.DB.Ping to ensure that connection is open, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/4). This way on Close it will surely assert if all expectations are met, even if database was not triggered at all. The old way is still available, but it is advisable to call db.Ping manually before asserting with db.Close. - **2014-02-14** RowsFromCSVString is now a part of Rows interface named as FromCSVString. It has changed to allow more ways to construct rows and to easily extend this API in future. See [issue 1](https://github.com/DATA-DOG/go-sqlmock/issues/1) **RowsFromCSVString** is deprecated and will be removed in future ## Contributions Feel free to open a pull request. 
Note, if you wish to contribute an extension to public (exported methods or types) - please open an issue before, to discuss whether these changes can be accepted. All backward incompatible changes are and will be treated cautiously ## License The [three clause BSD license](http://en.wikipedia.org/wiki/BSD_licenses) ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/argument.go ================================================ package sqlmock import "database/sql/driver" // Argument interface allows to match // any argument in specific way when used with // ExpectedQuery and ExpectedExec expectations. type Argument interface { Match(driver.Value) bool } // AnyArg will return an Argument which can // match any kind of arguments. // // Useful for time.Time or similar kinds of arguments. func AnyArg() Argument { return anyArgument{} } type anyArgument struct{} func (a anyArgument) Match(_ driver.Value) bool { return true } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/driver.go ================================================ package sqlmock import ( "database/sql" "database/sql/driver" "fmt" "sync" ) var pool *mockDriver func init() { pool = &mockDriver{ conns: make(map[string]*sqlmock), } sql.Register("sqlmock", pool) } type mockDriver struct { sync.Mutex counter int conns map[string]*sqlmock } func (d *mockDriver) Open(dsn string) (driver.Conn, error) { d.Lock() defer d.Unlock() c, ok := d.conns[dsn] if !ok { return c, fmt.Errorf("expected a connection to be available, but it is not") } c.opened++ return c, nil } // New creates sqlmock database connection // and a mock to manage expectations. // Pings db so that all expectations could be // asserted. 
func New() (*sql.DB, Sqlmock, error) { pool.Lock() dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) pool.counter++ smock := &sqlmock{dsn: dsn, drv: pool, ordered: true} pool.conns[dsn] = smock pool.Unlock() return smock.open() } // NewWithDSN creates sqlmock database connection // with a specific DSN and a mock to manage expectations. // Pings db so that all expectations could be asserted. // // This method is introduced because of sql abstraction // libraries, which do not provide a way to initialize // with sql.DB instance. For example GORM library. // // Note, it will error if attempted to create with an // already used dsn // // It is not recommended to use this method, unless you // really need it and there is no other way around. func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) { pool.Lock() if _, ok := pool.conns[dsn]; ok { pool.Unlock() return nil, nil, fmt.Errorf("cannot create a new mock database with the same dsn: %s", dsn) } smock := &sqlmock{dsn: dsn, drv: pool, ordered: true} pool.conns[dsn] = smock pool.Unlock() return smock.open() } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/expectations.go ================================================ package sqlmock import ( "database/sql/driver" "fmt" "regexp" "strings" "sync" "time" ) // an expectation interface type expectation interface { fulfilled() bool Lock() Unlock() String() string } // common expectation struct // satisfies the expectation interface type commonExpectation struct { sync.Mutex triggered bool err error } func (e *commonExpectation) fulfilled() bool { return e.triggered } // ExpectedClose is used to manage *sql.DB.Close expectation // returned by *Sqlmock.ExpectClose. 
type ExpectedClose struct {
	commonExpectation
}

// WillReturnError makes the expected *sql.DB.Close call fail with err.
// The receiver is returned so configuration calls can be chained.
func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
	e.err = err
	return e
}

// String describes the Close expectation, mentioning the configured
// error when there is one.
func (e *ExpectedClose) String() string {
	const what = "ExpectedClose => expecting database Close"
	if e.err == nil {
		return what
	}
	return fmt.Sprintf("%s, which should return error: %s", what, e.err)
}

// ExpectedBegin is used to manage *sql.DB.Begin expectation
// returned by *Sqlmock.ExpectBegin.
type ExpectedBegin struct {
	commonExpectation
	delay time.Duration
}

// WillReturnError makes the expected *sql.DB.Begin call fail with err.
// The receiver is returned so configuration calls can be chained.
func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
	e.err = err
	return e
}

// String describes the Begin expectation, mentioning the configured
// error when there is one.
func (e *ExpectedBegin) String() string {
	const what = "ExpectedBegin => expecting database transaction Begin"
	if e.err == nil {
		return what
	}
	return fmt.Sprintf("%s, which should return error: %s", what, e.err)
}

// WillDelayFor makes the matched Begin wait for duration before its result
// is delivered. It may be combined with a Context to exercise cancellation.
func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
	e.delay = duration
	return e
}

// ExpectedCommit is used to manage *sql.Tx.Commit expectation
// returned by *Sqlmock.ExpectCommit.
type ExpectedCommit struct {
	commonExpectation
}

// WillReturnError makes the expected *sql.Tx.Commit call fail with err.
// The receiver is returned so configuration calls can be chained.
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
	e.err = err
	return e
}

// String describes the Commit expectation, mentioning the configured
// error when there is one.
func (e *ExpectedCommit) String() string {
	const what = "ExpectedCommit => expecting transaction Commit"
	if e.err == nil {
		return what
	}
	return fmt.Sprintf("%s, which should return error: %s", what, e.err)
}

// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
// returned by *Sqlmock.ExpectRollback.
type ExpectedRollback struct {
	commonExpectation
}

// WillReturnError makes the expected *sql.Tx.Rollback call fail with err.
// The receiver is returned so configuration calls can be chained.
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
	e.err = err
	return e
}

// String describes the Rollback expectation, mentioning the configured
// error when there is one.
func (e *ExpectedRollback) String() string {
	const what = "ExpectedRollback => expecting transaction Rollback"
	if e.err == nil {
		return what
	}
	return fmt.Sprintf("%s, which should return error: %s", what, e.err)
}

// ExpectedQuery is used to manage *sql.DB.Query, *sql.DB.QueryRow, *sql.Tx.Query,
// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations.
// Returned by *Sqlmock.ExpectQuery.
type ExpectedQuery struct {
	queryBasedExpectation
	rows  driver.Rows
	delay time.Duration
}

// WithArgs sets the arguments the query must be called with. Every actual
// argument has to match its expected counterpart, otherwise matching fails
// with an error; pass an sqlmock.Argument to customise how a single argument
// is compared.
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
	e.args = args
	return e
}

// WillReturnError makes the expected query fail with err.
func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
	e.err = err
	return e
}

// WillDelayFor allows to specify duration for which it will delay
// result.
May be used together with Context func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { e.delay = duration return e } // String returns string representation func (e *ExpectedQuery) String() string { msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" if len(e.args) == 0 { msg += "\n - is without arguments" } else { msg += "\n - is with arguments:\n" for i, arg := range e.args { msg += fmt.Sprintf(" %d - %+v\n", i, arg) } msg = strings.TrimSpace(msg) } if e.rows != nil { msg += fmt.Sprintf("\n - %s", e.rows) } if e.err != nil { msg += fmt.Sprintf("\n - should return error: %s", e.err) } return msg } // ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations. // Returned by *Sqlmock.ExpectExec. type ExpectedExec struct { queryBasedExpectation result driver.Result delay time.Duration } // WithArgs will match given expected args to actual database exec operation arguments. // if at least one argument does not match, it will return an error. For specific // arguments an sqlmock.Argument interface can be used to match an argument. func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec { e.args = args return e } // WillReturnError allows to set an error for expected database exec action func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { e.err = err return e } // WillDelayFor allows to specify duration for which it will delay // result. 
May be used together with Context func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { e.delay = duration return e } // String returns string representation func (e *ExpectedExec) String() string { msg := "ExpectedExec => expecting Exec or ExecContext which:" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" if len(e.args) == 0 { msg += "\n - is without arguments" } else { msg += "\n - is with arguments:\n" var margs []string for i, arg := range e.args { margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg)) } msg += strings.Join(margs, "\n") } if e.result != nil { res, _ := e.result.(*result) msg += "\n - should return Result having:" msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID) msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected) if res.err != nil { msg += fmt.Sprintf("\n Error: %s", res.err) } } if e.err != nil { msg += fmt.Sprintf("\n - should return error: %s", e.err) } return msg } // WillReturnResult arranges for an expected Exec() to return a particular // result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method // to build a corresponding result. Or if actions needs to be tested against errors // sqlmock.NewErrorResult(err error) to return a given error. func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { e.result = result return e } // ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations. // Returned by *Sqlmock.ExpectPrepare. type ExpectedPrepare struct { commonExpectation mock *sqlmock sqlRegex *regexp.Regexp statement driver.Stmt closeErr error mustBeClosed bool wasClosed bool delay time.Duration } // WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. 
func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { e.err = err return e } // WillReturnCloseError allows to set an error for this prepared statement Close action func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { e.closeErr = err return e } // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare { e.delay = duration return e } // WillBeClosed expects this prepared statement to // be closed. func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { e.mustBeClosed = true return e } // ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. // this method is convenient in order to prevent duplicating sql query string matching. func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { eq := &ExpectedQuery{} eq.sqlRegex = e.sqlRegex e.mock.expected = append(e.mock.expected, eq) return eq } // ExpectExec allows to expect Exec() on this prepared statement. // this method is convenient in order to prevent duplicating sql query string matching. 
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { eq := &ExpectedExec{} eq.sqlRegex = e.sqlRegex e.mock.expected = append(e.mock.expected, eq) return eq } // String returns string representation func (e *ExpectedPrepare) String() string { msg := "ExpectedPrepare => expecting Prepare statement which:" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" if e.err != nil { msg += fmt.Sprintf("\n - should return error: %s", e.err) } if e.closeErr != nil { msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr) } return msg } // query based expectation // adds a query matching logic type queryBasedExpectation struct { commonExpectation sqlRegex *regexp.Regexp args []driver.Value } func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) { if !e.queryMatches(sql) { return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String()) } // catch panic defer func() { if e := recover(); e != nil { _, ok := e.(error) if !ok { err = fmt.Errorf(e.(string)) } } }() err = e.argsMatches(args) return } func (e *queryBasedExpectation) queryMatches(sql string) bool { return e.sqlRegex.MatchString(sql) } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/expectations_before_go18.go ================================================ // +build !go1.8 package sqlmock import ( "database/sql/driver" "fmt" "reflect" ) // WillReturnRows specifies the set of resulting rows that will be returned // by the triggered query func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery { e.rows = &rowSets{sets: []*Rows{rows}} return e } func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil } if len(args) != len(e.args) { return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) } for k, v := range args { // custom argument matcher matcher, ok := e.args[k].(Argument) if ok { // @TODO: does it make sense 
to pass value instead of named value? if !matcher.Match(v.Value) { return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) } continue } dval := e.args[k] // convert to driver converter darg, err := driver.DefaultParameterConverter.ConvertValue(dval) if err != nil { return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) } if !driver.IsValue(darg) { return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) } if !reflect.DeepEqual(darg, v.Value) { return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) } } return nil } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/expectations_go18.go ================================================ // +build go1.8 package sqlmock import ( "database/sql" "database/sql/driver" "fmt" "reflect" ) // WillReturnRows specifies the set of resulting rows that will be returned // by the triggered query func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { sets := make([]*Rows, len(rows)) for i, r := range rows { sets[i] = r } e.rows = &rowSets{sets: sets} return e } func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { return nil } if len(args) != len(e.args) { return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) } // @TODO should we assert either all args are named or ordinal? 
for k, v := range args { // custom argument matcher matcher, ok := e.args[k].(Argument) if ok { if !matcher.Match(v.Value) { return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) } continue } dval := e.args[k] if named, isNamed := dval.(sql.NamedArg); isNamed { dval = named.Value if v.Name != named.Name { return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name) } } else if k+1 != v.Ordinal { return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal) } // convert to driver converter darg, err := driver.DefaultParameterConverter.ConvertValue(dval) if err != nil { return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err) } if !driver.IsValue(darg) { return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg) } if !reflect.DeepEqual(darg, v.Value) { return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value) } } return nil } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/result.go ================================================ package sqlmock import ( "database/sql/driver" ) // Result satisfies sql driver Result, which // holds last insert id and rows affected // by Exec queries type result struct { insertID int64 rowsAffected int64 err error } // NewResult creates a new sql driver Result // for Exec based query mocks. 
func NewResult(lastInsertID int64, rowsAffected int64) driver.Result { return &result{ insertID: lastInsertID, rowsAffected: rowsAffected, } } // NewErrorResult creates a new sql driver Result // which returns an error given for both interface methods func NewErrorResult(err error) driver.Result { return &result{ err: err, } } func (r *result) LastInsertId() (int64, error) { return r.insertID, r.err } func (r *result) RowsAffected() (int64, error) { return r.rowsAffected, r.err } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/rows.go ================================================ package sqlmock import ( "database/sql/driver" "encoding/csv" "fmt" "io" "strings" ) // CSVColumnParser is a function which converts trimmed csv // column string to a []byte representation. currently // transforms NULL to nil var CSVColumnParser = func(s string) []byte { switch { case strings.ToLower(s) == "null": return nil } return []byte(s) } type rowSets struct { sets []*Rows pos int } func (rs *rowSets) Columns() []string { return rs.sets[rs.pos].cols } func (rs *rowSets) Close() error { return rs.sets[rs.pos].closeErr } // advances to next row func (rs *rowSets) Next(dest []driver.Value) error { r := rs.sets[rs.pos] r.pos++ if r.pos > len(r.rows) { return io.EOF // per interface spec } for i, col := range r.rows[r.pos-1] { dest[i] = col } return r.nextErr[r.pos-1] } // transforms to debuggable printable string func (rs *rowSets) String() string { if rs.empty() { return "with empty rows" } msg := "should return rows:\n" if len(rs.sets) == 1 { for n, row := range rs.sets[0].rows { msg += fmt.Sprintf(" row %d - %+v\n", n, row) } return strings.TrimSpace(msg) } for i, set := range rs.sets { msg += fmt.Sprintf(" result set: %d\n", i) for n, row := range set.rows { msg += fmt.Sprintf(" row %d - %+v\n", n, row) } } return strings.TrimSpace(msg) } func (rs *rowSets) empty() bool { for _, set := range rs.sets { if len(set.rows) > 0 { return 
false } } return true } // Rows is a mocked collection of rows to // return for Query result type Rows struct { cols []string rows [][]driver.Value pos int nextErr map[int]error closeErr error } // NewRows allows Rows to be created from a // sql driver.Value slice or from the CSV string and // to be used as sql driver.Rows func NewRows(columns []string) *Rows { return &Rows{cols: columns, nextErr: make(map[int]error)} } // CloseError allows to set an error // which will be returned by rows.Close // function. // // The close error will be triggered only in cases // when rows.Next() EOF was not yet reached, that is // a default sql library behavior func (r *Rows) CloseError(err error) *Rows { r.closeErr = err return r } // RowError allows to set an error // which will be returned when a given // row number is read func (r *Rows) RowError(row int, err error) *Rows { r.nextErr[row] = err return r } // AddRow composed from database driver.Value slice // return the same instance to perform subsequent actions. // Note that the number of values must match the number // of columns func (r *Rows) AddRow(values ...driver.Value) *Rows { if len(values) != len(r.cols) { panic("Expected number of values to match number of columns") } row := make([]driver.Value, len(r.cols)) for i, v := range values { row[i] = v } r.rows = append(r.rows, row) return r } // FromCSVString build rows from csv string. // return the same instance to perform subsequent actions. 
// Note that the number of values must match the number // of columns func (r *Rows) FromCSVString(s string) *Rows { res := strings.NewReader(strings.TrimSpace(s)) csvReader := csv.NewReader(res) for { res, err := csvReader.Read() if err != nil || res == nil { break } row := make([]driver.Value, len(r.cols)) for i, v := range res { row[i] = CSVColumnParser(strings.TrimSpace(v)) } r.rows = append(r.rows, row) } return r } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/rows_go18.go ================================================ // +build go1.8 package sqlmock import "io" // Implement the "RowsNextResultSet" interface func (rs *rowSets) HasNextResultSet() bool { return rs.pos+1 < len(rs.sets) } // Implement the "RowsNextResultSet" interface func (rs *rowSets) NextResultSet() error { if !rs.HasNextResultSet() { return io.EOF } rs.pos++ return nil } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/sqlmock.go ================================================ /* Package sqlmock is a mock library implementing sql driver. Which has one and only purpose - to simulate any sql driver behavior in tests, without needing a real database connection. It helps to maintain correct **TDD** workflow. It does not require any modifications to your source code in order to test and mock database operations. Supports concurrency and multiple database mocking. The driver allows to mock any sql driver method behavior. */ package sqlmock import ( "database/sql" "database/sql/driver" "fmt" "regexp" "time" ) // Sqlmock interface serves to create expectations // for any kind of database action in order to mock // and test real database behavior. type Sqlmock interface { // ExpectClose queues an expectation for this database // action to be triggered. 
the *ExpectedClose allows // to mock database response ExpectClose() *ExpectedClose // ExpectationsWereMet checks whether all queued expectations // were met in order. If any of them was not met - an error is returned. ExpectationsWereMet() error // ExpectPrepare expects Prepare() to be called with sql query // which match sqlRegexStr given regexp. // the *ExpectedPrepare allows to mock database response. // Note that you may expect Query() or Exec() on the *ExpectedPrepare // statement to prevent repeating sqlRegexStr ExpectPrepare(sqlRegexStr string) *ExpectedPrepare // ExpectQuery expects Query() or QueryRow() to be called with sql query // which match sqlRegexStr given regexp. // the *ExpectedQuery allows to mock database response. ExpectQuery(sqlRegexStr string) *ExpectedQuery // ExpectExec expects Exec() to be called with sql query // which match sqlRegexStr given regexp. // the *ExpectedExec allows to mock database response ExpectExec(sqlRegexStr string) *ExpectedExec // ExpectBegin expects *sql.DB.Begin to be called. // the *ExpectedBegin allows to mock database response ExpectBegin() *ExpectedBegin // ExpectCommit expects *sql.Tx.Commit to be called. // the *ExpectedCommit allows to mock database response ExpectCommit() *ExpectedCommit // ExpectRollback expects *sql.Tx.Rollback to be called. // the *ExpectedRollback allows to mock database response ExpectRollback() *ExpectedRollback // MatchExpectationsInOrder gives an option whether to match all // expectations in the order they were set or not. // // By default it is set to - true. But if you use goroutines // to parallelize your query executation, that option may // be handy. // // This option may be turned on anytime during tests. As soon // as it is switched to false, expectations will be matched // in any order. 
Or otherwise if switched to true, any unmatched // expectations will be expected in order MatchExpectationsInOrder(bool) } type sqlmock struct { ordered bool dsn string opened int drv *mockDriver expected []expectation } func (c *sqlmock) open() (*sql.DB, Sqlmock, error) { db, err := sql.Open("sqlmock", c.dsn) if err != nil { return db, c, err } return db, c, db.Ping() } func (c *sqlmock) ExpectClose() *ExpectedClose { e := &ExpectedClose{} c.expected = append(c.expected, e) return e } func (c *sqlmock) MatchExpectationsInOrder(b bool) { c.ordered = b } // Close a mock database driver connection. It may or may not // be called depending on the sircumstances, but if it is called // there must be an *ExpectedClose expectation satisfied. // meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Close() error { c.drv.Lock() defer c.drv.Unlock() c.opened-- if c.opened == 0 { delete(c.drv.conns, c.dsn) } var expected *ExpectedClose var fulfilled int var ok bool for _, next := range c.expected { next.Lock() if next.fulfilled() { next.Unlock() fulfilled++ continue } if expected, ok = next.(*ExpectedClose); ok { break } next.Unlock() if c.ordered { return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next) } } if expected == nil { msg := "call to database Close was not expected" if fulfilled == len(c.expected) { msg = "all expectations were already fulfilled, " + msg } return fmt.Errorf(msg) } expected.triggered = true expected.Unlock() return expected.err } func (c *sqlmock) ExpectationsWereMet() error { for _, e := range c.expected { if !e.fulfilled() { return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) } // for expected prepared statement check whether it was closed if expected if prep, ok := e.(*ExpectedPrepare); ok { if prep.mustBeClosed && !prep.wasClosed { return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep) } } } return nil } // Begin meets 
http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Begin() (driver.Tx, error) { ex, err := c.begin() if ex != nil { time.Sleep(ex.delay) } if err != nil { return nil, err } return c, nil } func (c *sqlmock) begin() (*ExpectedBegin, error) { var expected *ExpectedBegin var ok bool var fulfilled int for _, next := range c.expected { next.Lock() if next.fulfilled() { next.Unlock() fulfilled++ continue } if expected, ok = next.(*ExpectedBegin); ok { break } next.Unlock() if c.ordered { return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next) } } if expected == nil { msg := "call to database transaction Begin was not expected" if fulfilled == len(c.expected) { msg = "all expectations were already fulfilled, " + msg } return nil, fmt.Errorf(msg) } expected.triggered = true expected.Unlock() return expected, expected.err } func (c *sqlmock) ExpectBegin() *ExpectedBegin { e := &ExpectedBegin{} c.expected = append(c.expected, e) return e } // Exec meets http://golang.org/pkg/database/sql/driver/#Execer func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { namedArgs := make([]namedValue, len(args)) for i, v := range args { namedArgs[i] = namedValue{ Ordinal: i + 1, Value: v, } } ex, err := c.exec(query, namedArgs) if ex != nil { time.Sleep(ex.delay) } if err != nil { return nil, err } return ex.result, nil } func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { query = stripQuery(query) var expected *ExpectedExec var fulfilled int var ok bool for _, next := range c.expected { next.Lock() if next.fulfilled() { next.Unlock() fulfilled++ continue } if c.ordered { if expected, ok = next.(*ExpectedExec); ok { break } next.Unlock() return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) } if exec, ok := next.(*ExpectedExec); ok { if err := exec.attemptMatch(query, args); err == nil 
{ expected = exec break } } next.Unlock() } if expected == nil { msg := "call to ExecQuery '%s' with args %+v was not expected" if fulfilled == len(c.expected) { msg = "all expectations were already fulfilled, " + msg } return nil, fmt.Errorf(msg, query, args) } defer expected.Unlock() if !expected.queryMatches(query) { return nil, fmt.Errorf("ExecQuery '%s', does not match regex '%s'", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) } expected.triggered = true if expected.err != nil { return expected, expected.err // mocked to return error } if expected.result == nil { return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) } return expected, nil } func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { e := &ExpectedExec{} sqlRegexStr = stripQuery(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr) c.expected = append(c.expected, e) return e } // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { ex, err := c.prepare(query) if ex != nil { time.Sleep(ex.delay) } if err != nil { return nil, err } return &statement{c, ex, query}, nil } func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { var expected *ExpectedPrepare var fulfilled int var ok bool query = stripQuery(query) for _, next := range c.expected { next.Lock() if next.fulfilled() { next.Unlock() fulfilled++ continue } if c.ordered { if expected, ok = next.(*ExpectedPrepare); ok { break } next.Unlock() return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next) } if pr, ok := next.(*ExpectedPrepare); ok { if pr.sqlRegex.MatchString(query) { expected = pr break } } next.Unlock() } if expected == nil 
{ msg := "call to Prepare '%s' query was not expected" if fulfilled == len(c.expected) { msg = "all expectations were already fulfilled, " + msg } return nil, fmt.Errorf(msg, query) } defer expected.Unlock() if !expected.sqlRegex.MatchString(query) { return nil, fmt.Errorf("Prepare query string '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } expected.triggered = true return expected, expected.err } func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { sqlRegexStr = stripQuery(sqlRegexStr) e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c} c.expected = append(c.expected, e) return e } type namedValue struct { Name string Ordinal int Value driver.Value } // Query meets http://golang.org/pkg/database/sql/driver/#Queryer func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { namedArgs := make([]namedValue, len(args)) for i, v := range args { namedArgs[i] = namedValue{ Ordinal: i + 1, Value: v, } } ex, err := c.query(query, namedArgs) if ex != nil { time.Sleep(ex.delay) } if err != nil { return nil, err } return ex.rows, nil } func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { query = stripQuery(query) var expected *ExpectedQuery var fulfilled int var ok bool for _, next := range c.expected { next.Lock() if next.fulfilled() { next.Unlock() fulfilled++ continue } if c.ordered { if expected, ok = next.(*ExpectedQuery); ok { break } next.Unlock() return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) } if qr, ok := next.(*ExpectedQuery); ok { if err := qr.attemptMatch(query, args); err == nil { expected = qr break } } next.Unlock() } if expected == nil { msg := "call to Query '%s' with args %+v was not expected" if fulfilled == len(c.expected) { msg = "all expectations were already fulfilled, " + msg } return nil, fmt.Errorf(msg, query, args) } defer expected.Unlock() if 
!expected.queryMatches(query) { return nil, fmt.Errorf("Query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) } if err := expected.argsMatches(args); err != nil { return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) } expected.triggered = true if expected.err != nil { return expected, expected.err // mocked to return error } if expected.rows == nil { return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) } return expected, nil } func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { e := &ExpectedQuery{} sqlRegexStr = stripQuery(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr) c.expected = append(c.expected, e) return e } func (c *sqlmock) ExpectCommit() *ExpectedCommit { e := &ExpectedCommit{} c.expected = append(c.expected, e) return e } func (c *sqlmock) ExpectRollback() *ExpectedRollback { e := &ExpectedRollback{} c.expected = append(c.expected, e) return e } // Commit meets http://golang.org/pkg/database/sql/driver/#Tx func (c *sqlmock) Commit() error { var expected *ExpectedCommit var fulfilled int var ok bool for _, next := range c.expected { next.Lock() if next.fulfilled() { next.Unlock() fulfilled++ continue } if expected, ok = next.(*ExpectedCommit); ok { break } next.Unlock() if c.ordered { return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next) } } if expected == nil { msg := "call to Commit transaction was not expected" if fulfilled == len(c.expected) { msg = "all expectations were already fulfilled, " + msg } return fmt.Errorf(msg) } expected.triggered = true expected.Unlock() return expected.err } // Rollback meets http://golang.org/pkg/database/sql/driver/#Tx func (c *sqlmock) Rollback() error { var expected *ExpectedRollback var fulfilled int var ok bool for _, next := range c.expected { next.Lock() if next.fulfilled() { 
next.Unlock() fulfilled++ continue } if expected, ok = next.(*ExpectedRollback); ok { break } next.Unlock() if c.ordered { return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next) } } if expected == nil { msg := "call to Rollback transaction was not expected" if fulfilled == len(c.expected) { msg = "all expectations were already fulfilled, " + msg } return fmt.Errorf(msg) } expected.triggered = true expected.Unlock() return expected.err } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/sqlmock_go18.go ================================================ // +build go1.8 package sqlmock import ( "context" "database/sql/driver" "errors" "time" ) var ErrCancelled = errors.New("canceling query due to user request") // Implement the "QueryerContext" interface func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { namedArgs := make([]namedValue, len(args)) for i, nv := range args { namedArgs[i] = namedValue(nv) } ex, err := c.query(query, namedArgs) if ex != nil { select { case <-time.After(ex.delay): if err != nil { return nil, err } return ex.rows, nil case <-ctx.Done(): return nil, ErrCancelled } } return nil, err } // Implement the "ExecerContext" interface func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { namedArgs := make([]namedValue, len(args)) for i, nv := range args { namedArgs[i] = namedValue(nv) } ex, err := c.exec(query, namedArgs) if ex != nil { select { case <-time.After(ex.delay): if err != nil { return nil, err } return ex.result, nil case <-ctx.Done(): return nil, ErrCancelled } } return nil, err } // Implement the "ConnBeginTx" interface func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { ex, err := c.begin() if ex != nil { select { case <-time.After(ex.delay): if err != nil { return nil, err } return c, nil 
case <-ctx.Done(): return nil, ErrCancelled } } return nil, err } // Implement the "ConnPrepareContext" interface func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { ex, err := c.prepare(query) if ex != nil { select { case <-time.After(ex.delay): if err != nil { return nil, err } return &statement{c, ex, query}, nil case <-ctx.Done(): return nil, ErrCancelled } } return nil, err } // Implement the "Pinger" interface // for now we do not have a Ping expectation // may be something for the future func (c *sqlmock) Ping(ctx context.Context) error { return nil } // Implement the "StmtExecContext" interface func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { return stmt.conn.ExecContext(ctx, stmt.query, args) } // Implement the "StmtQueryContext" interface func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { return stmt.conn.QueryContext(ctx, stmt.query, args) } // @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/statement.go ================================================ package sqlmock import ( "database/sql/driver" ) type statement struct { conn *sqlmock ex *ExpectedPrepare query string } func (stmt *statement) Close() error { stmt.ex.wasClosed = true return stmt.ex.closeErr } func (stmt *statement) NumInput() int { return -1 } func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { return stmt.conn.Exec(stmt.query, args) } func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { return stmt.conn.Query(stmt.query, args) } ================================================ FILE: vendor/github.com/DATA-DOG/go-sqlmock/util.go ================================================ package sqlmock import ( "regexp" "strings" ) var re = regexp.MustCompile("\\s+") // strip out new lines and trim 
spaces func stripQuery(q string) (s string) { return strings.TrimSpace(re.ReplaceAllString(q, " ")) } ================================================ FILE: vendor/github.com/davecgh/go-spew/LICENSE ================================================ ISC License Copyright (c) 2012-2013 Dave Collins Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ================================================ FILE: vendor/github.com/davecgh/go-spew/spew/bypass.go ================================================ // Copyright (c) 2015 Dave Collins // // Permission to use, copy, modify, and distribute this software for any // purpose with or without fee is hereby granted, provided that the above // copyright notice and this permission notice appear in all copies. // // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
// NOTE: Due to the following build constraints, this file will only be compiled // when the code is not running on Google App Engine, compiled by GopherJS, and // "-tags safe" is not added to the go build command line. The "disableunsafe" // tag is deprecated and thus should not be used. // +build !js,!appengine,!safe,!disableunsafe package spew import ( "reflect" "unsafe" ) const ( // UnsafeDisabled is a build-time constant which specifies whether or // not access to the unsafe package is available. UnsafeDisabled = false // ptrSize is the size of a pointer on the current arch. ptrSize = unsafe.Sizeof((*byte)(nil)) ) var ( // offsetPtr, offsetScalar, and offsetFlag are the offsets for the // internal reflect.Value fields. These values are valid before golang // commit ecccf07e7f9d which changed the format. The are also valid // after commit 82f48826c6c7 which changed the format again to mirror // the original format. Code in the init function updates these offsets // as necessary. offsetPtr = uintptr(ptrSize) offsetScalar = uintptr(0) offsetFlag = uintptr(ptrSize * 2) // flagKindWidth and flagKindShift indicate various bits that the // reflect package uses internally to track kind information. // // flagRO indicates whether or not the value field of a reflect.Value is // read-only. // // flagIndir indicates whether the value field of a reflect.Value is // the actual data or a pointer to the data. // // These values are valid before golang commit 90a7c3c86944 which // changed their positions. Code in the init function updates these // flags as necessary. flagKindWidth = uintptr(5) flagKindShift = uintptr(flagKindWidth - 1) flagRO = uintptr(1 << 0) flagIndir = uintptr(1 << 1) ) func init() { // Older versions of reflect.Value stored small integers directly in the // ptr field (which is named val in the older versions). 
Versions // between commits ecccf07e7f9d and 82f48826c6c7 added a new field named // scalar for this purpose which unfortunately came before the flag // field, so the offset of the flag field is different for those // versions. // // This code constructs a new reflect.Value from a known small integer // and checks if the size of the reflect.Value struct indicates it has // the scalar field. When it does, the offsets are updated accordingly. vv := reflect.ValueOf(0xf00) if unsafe.Sizeof(vv) == (ptrSize * 4) { offsetScalar = ptrSize * 2 offsetFlag = ptrSize * 3 } // Commit 90a7c3c86944 changed the flag positions such that the low // order bits are the kind. This code extracts the kind from the flags // field and ensures it's the correct type. When it's not, the flag // order has been changed to the newer format, so the flags are updated // accordingly. upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag) upfv := *(*uintptr)(upf) flagKindMask := uintptr((1<>flagKindShift != uintptr(reflect.Int) { flagKindShift = 0 flagRO = 1 << 5 flagIndir = 1 << 6 // Commit adf9b30e5594 modified the flags to separate the // flagRO flag into two bits which specifies whether or not the // field is embedded. This causes flagIndir to move over a bit // and means that flagRO is the combination of either of the // original flagRO bit and the new bit. // // This code detects the change by extracting what used to be // the indirect bit to ensure it's set. When it's not, the flag // order has been changed to the newer format, so the flags are // updated accordingly. if upfv&flagIndir == 0 { flagRO = 3 << 5 flagIndir = 1 << 7 } } } // unsafeReflectValue converts the passed reflect.Value into a one that bypasses // the typical safety restrictions preventing access to unaddressable and // unexported data. It works by digging the raw pointer to the underlying // value out of the protected value and generating a new unprotected (unsafe) // reflect.Value to it. 
// // This allows us to check for implementations of the Stringer and error // interfaces to be used for pretty printing ordinarily unaddressable and // inaccessible values such as unexported struct fields. func unsafeReflectValue(v reflect.Value) (rv reflect.Value) { indirects := 1 vt := v.Type() upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr) rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag)) if rvf&flagIndir != 0 { vt = reflect.PtrTo(v.Type()) indirects++ } else if offsetScalar != 0 { // The value is in the scalar field when it's not one of the // reference types. switch vt.Kind() { case reflect.Uintptr: case reflect.Chan: case reflect.Func: case reflect.Map: case reflect.Ptr: case reflect.UnsafePointer: default: upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetScalar) } } pv := reflect.NewAt(vt, upv) rv = pv for i := 0; i < indirects; i++ { rv = rv.Elem() } return rv } ================================================ FILE: vendor/github.com/davecgh/go-spew/spew/bypasssafe.go ================================================ // Copyright (c) 2015 Dave Collins // // Permission to use, copy, modify, and distribute this software for any // purpose with or without fee is hereby granted, provided that the above // copyright notice and this permission notice appear in all copies. // // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR // ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
// NOTE: Due to the following build constraints, this file will only be compiled // when the code is running on Google App Engine, compiled by GopherJS, or // "-tags safe" is added to the go build command line. The "disableunsafe" // tag is deprecated and thus should not be used. // +build js appengine safe disableunsafe package spew import "reflect" const ( // UnsafeDisabled is a build-time constant which specifies whether or // not access to the unsafe package is available. UnsafeDisabled = true ) // unsafeReflectValue typically converts the passed reflect.Value into a one // that bypasses the typical safety restrictions preventing access to // unaddressable and unexported data. However, doing this relies on access to // the unsafe package. This is a stub version which simply returns the passed // reflect.Value when the unsafe package is not available. func unsafeReflectValue(v reflect.Value) reflect.Value { return v } ================================================ FILE: vendor/github.com/davecgh/go-spew/spew/common.go ================================================ /* * Copyright (c) 2013 Dave Collins * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ package spew import ( "bytes" "fmt" "io" "reflect" "sort" "strconv" ) // Some constants in the form of bytes to avoid string overhead. 
This mirrors // the technique used in the fmt package. var ( panicBytes = []byte("(PANIC=") plusBytes = []byte("+") iBytes = []byte("i") trueBytes = []byte("true") falseBytes = []byte("false") interfaceBytes = []byte("(interface {})") commaNewlineBytes = []byte(",\n") newlineBytes = []byte("\n") openBraceBytes = []byte("{") openBraceNewlineBytes = []byte("{\n") closeBraceBytes = []byte("}") asteriskBytes = []byte("*") colonBytes = []byte(":") colonSpaceBytes = []byte(": ") openParenBytes = []byte("(") closeParenBytes = []byte(")") spaceBytes = []byte(" ") pointerChainBytes = []byte("->") nilAngleBytes = []byte("") maxNewlineBytes = []byte("\n") maxShortBytes = []byte("") circularBytes = []byte("") circularShortBytes = []byte("") invalidAngleBytes = []byte("") openBracketBytes = []byte("[") closeBracketBytes = []byte("]") percentBytes = []byte("%") precisionBytes = []byte(".") openAngleBytes = []byte("<") closeAngleBytes = []byte(">") openMapBytes = []byte("map[") closeMapBytes = []byte("]") lenEqualsBytes = []byte("len=") capEqualsBytes = []byte("cap=") ) // hexDigits is used to map a decimal value to a hex digit. var hexDigits = "0123456789abcdef" // catchPanic handles any panics that might occur during the handleMethods // calls. func catchPanic(w io.Writer, v reflect.Value) { if err := recover(); err != nil { w.Write(panicBytes) fmt.Fprintf(w, "%v", err) w.Write(closeParenBytes) } } // handleMethods attempts to call the Error and String methods on the underlying // type the passed reflect.Value represents and outputes the result to Writer w. // // It handles panics in any called methods by catching and displaying the error // as the formatted value. func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) { // We need an interface to check if the type implements the error or // Stringer interface. 
However, the reflect package won't give us an // interface on certain things like unexported struct fields in order // to enforce visibility rules. We use unsafe, when it's available, // to bypass these restrictions since this package does not mutate the // values. if !v.CanInterface() { if UnsafeDisabled { return false } v = unsafeReflectValue(v) } // Choose whether or not to do error and Stringer interface lookups against // the base type or a pointer to the base type depending on settings. // Technically calling one of these methods with a pointer receiver can // mutate the value, however, types which choose to satisify an error or // Stringer interface with a pointer receiver should not be mutating their // state inside these interface methods. if !cs.DisablePointerMethods && !UnsafeDisabled && !v.CanAddr() { v = unsafeReflectValue(v) } if v.CanAddr() { v = v.Addr() } // Is it an error or Stringer? switch iface := v.Interface().(type) { case error: defer catchPanic(w, v) if cs.ContinueOnMethod { w.Write(openParenBytes) w.Write([]byte(iface.Error())) w.Write(closeParenBytes) w.Write(spaceBytes) return false } w.Write([]byte(iface.Error())) return true case fmt.Stringer: defer catchPanic(w, v) if cs.ContinueOnMethod { w.Write(openParenBytes) w.Write([]byte(iface.String())) w.Write(closeParenBytes) w.Write(spaceBytes) return false } w.Write([]byte(iface.String())) return true } return false } // printBool outputs a boolean value as true or false to Writer w. func printBool(w io.Writer, val bool) { if val { w.Write(trueBytes) } else { w.Write(falseBytes) } } // printInt outputs a signed integer value to Writer w. func printInt(w io.Writer, val int64, base int) { w.Write([]byte(strconv.FormatInt(val, base))) } // printUint outputs an unsigned integer value to Writer w. 
func printUint(w io.Writer, val uint64, base int) { w.Write([]byte(strconv.FormatUint(val, base))) } // printFloat outputs a floating point value using the specified precision, // which is expected to be 32 or 64bit, to Writer w. func printFloat(w io.Writer, val float64, precision int) { w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision))) } // printComplex outputs a complex value using the specified float precision // for the real and imaginary parts to Writer w. func printComplex(w io.Writer, c complex128, floatPrecision int) { r := real(c) w.Write(openParenBytes) w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision))) i := imag(c) if i >= 0 { w.Write(plusBytes) } w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision))) w.Write(iBytes) w.Write(closeParenBytes) } // printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x' // prefix to Writer w. func printHexPtr(w io.Writer, p uintptr) { // Null pointer. num := uint64(p) if num == 0 { w.Write(nilAngleBytes) return } // Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix buf := make([]byte, 18) // It's simpler to construct the hex string right to left. base := uint64(16) i := len(buf) - 1 for num >= base { buf[i] = hexDigits[num%base] num /= base i-- } buf[i] = hexDigits[num] // Add '0x' prefix. i-- buf[i] = 'x' i-- buf[i] = '0' // Strip unused leading bytes. buf = buf[i:] w.Write(buf) } // valuesSorter implements sort.Interface to allow a slice of reflect.Value // elements to be sorted. type valuesSorter struct { values []reflect.Value strings []string // either nil or same len and values cs *ConfigState } // newValuesSorter initializes a valuesSorter instance, which holds a set of // surrogate keys on which the data should be sorted. It uses flags in // ConfigState to decide if and how to populate those surrogate keys. 
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface { vs := &valuesSorter{values: values, cs: cs} if canSortSimply(vs.values[0].Kind()) { return vs } if !cs.DisableMethods { vs.strings = make([]string, len(values)) for i := range vs.values { b := bytes.Buffer{} if !handleMethods(cs, &b, vs.values[i]) { vs.strings = nil break } vs.strings[i] = b.String() } } if vs.strings == nil && cs.SpewKeys { vs.strings = make([]string, len(values)) for i := range vs.values { vs.strings[i] = Sprintf("%#v", vs.values[i].Interface()) } } return vs } // canSortSimply tests whether a reflect.Kind is a primitive that can be sorted // directly, or whether it should be considered for sorting by surrogate keys // (if the ConfigState allows it). func canSortSimply(kind reflect.Kind) bool { // This switch parallels valueSortLess, except for the default case. switch kind { case reflect.Bool: return true case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: return true case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: return true case reflect.Float32, reflect.Float64: return true case reflect.String: return true case reflect.Uintptr: return true case reflect.Array: return true } return false } // Len returns the number of values in the slice. It is part of the // sort.Interface implementation. func (s *valuesSorter) Len() int { return len(s.values) } // Swap swaps the values at the passed indices. It is part of the // sort.Interface implementation. func (s *valuesSorter) Swap(i, j int) { s.values[i], s.values[j] = s.values[j], s.values[i] if s.strings != nil { s.strings[i], s.strings[j] = s.strings[j], s.strings[i] } } // valueSortLess returns whether the first value should sort before the second // value. It is used by valueSorter.Less as part of the sort.Interface // implementation. 
func valueSortLess(a, b reflect.Value) bool { switch a.Kind() { case reflect.Bool: return !a.Bool() && b.Bool() case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: return a.Int() < b.Int() case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: return a.Uint() < b.Uint() case reflect.Float32, reflect.Float64: return a.Float() < b.Float() case reflect.String: return a.String() < b.String() case reflect.Uintptr: return a.Uint() < b.Uint() case reflect.Array: // Compare the contents of both arrays. l := a.Len() for i := 0; i < l; i++ { av := a.Index(i) bv := b.Index(i) if av.Interface() == bv.Interface() { continue } return valueSortLess(av, bv) } } return a.String() < b.String() } // Less returns whether the value at index i should sort before the // value at index j. It is part of the sort.Interface implementation. func (s *valuesSorter) Less(i, j int) bool { if s.strings == nil { return valueSortLess(s.values[i], s.values[j]) } return s.strings[i] < s.strings[j] } // sortValues is a sort function that handles both native types and any type that // can be converted to error or Stringer. Other inputs are sorted according to // their Value.String() value to ensure display stability. func sortValues(values []reflect.Value, cs *ConfigState) { if len(values) == 0 { return } sort.Sort(newValuesSorter(values, cs)) } ================================================ FILE: vendor/github.com/davecgh/go-spew/spew/config.go ================================================ /* * Copyright (c) 2013 Dave Collins * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ package spew import ( "bytes" "fmt" "io" "os" ) // ConfigState houses the configuration options used by spew to format and // display values. There is a global instance, Config, that is used to control // all top-level Formatter and Dump functionality. Each ConfigState instance // provides methods equivalent to the top-level functions. // // The zero value for ConfigState provides no indentation. You would typically // want to set it to a space or a tab. // // Alternatively, you can use NewDefaultConfig to get a ConfigState instance // with default settings. See the documentation of NewDefaultConfig for default // values. type ConfigState struct { // Indent specifies the string to use for each indentation level. The // global config instance that all top-level functions use set this to a // single space by default. If you would like more indentation, you might // set this to a tab with "\t" or perhaps two spaces with " ". Indent string // MaxDepth controls the maximum number of levels to descend into nested // data structures. The default, 0, means there is no limit. // // NOTE: Circular data structures are properly detected, so it is not // necessary to set this value unless you specifically want to limit deeply // nested data structures. MaxDepth int // DisableMethods specifies whether or not error and Stringer interfaces are // invoked for types that implement them. DisableMethods bool // DisablePointerMethods specifies whether or not to check for and invoke // error and Stringer interfaces on types which only accept a pointer // receiver when the current type is not a pointer. 
// // NOTE: This might be an unsafe action since calling one of these methods // with a pointer receiver could technically mutate the value, however, // in practice, types which choose to satisify an error or Stringer // interface with a pointer receiver should not be mutating their state // inside these interface methods. As a result, this option relies on // access to the unsafe package, so it will not have any effect when // running in environments without access to the unsafe package such as // Google App Engine or with the "safe" build tag specified. DisablePointerMethods bool // DisablePointerAddresses specifies whether to disable the printing of // pointer addresses. This is useful when diffing data structures in tests. DisablePointerAddresses bool // DisableCapacities specifies whether to disable the printing of capacities // for arrays, slices, maps and channels. This is useful when diffing // data structures in tests. DisableCapacities bool // ContinueOnMethod specifies whether or not recursion should continue once // a custom error or Stringer interface is invoked. The default, false, // means it will print the results of invoking the custom error or Stringer // interface and return immediately instead of continuing to recurse into // the internals of the data type. // // NOTE: This flag does not have any effect if method invocation is disabled // via the DisableMethods or DisablePointerMethods options. ContinueOnMethod bool // SortKeys specifies map keys should be sorted before being printed. Use // this to have a more deterministic, diffable output. Note that only // native types (bool, int, uint, floats, uintptr and string) and types // that support the error or Stringer interfaces (if methods are // enabled) are supported, with other types sorted according to the // reflect.Value.String() output which guarantees display stability. 
SortKeys bool // SpewKeys specifies that, as a last resort attempt, map keys should // be spewed to strings and sorted by those strings. This is only // considered if SortKeys is true. SpewKeys bool } // Config is the active configuration of the top-level functions. // The configuration can be changed by modifying the contents of spew.Config. var Config = ConfigState{Indent: " "} // Errorf is a wrapper for fmt.Errorf that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the formatted string as a value that satisfies error. See NewFormatter // for formatting details. // // This function is shorthand for the following syntax: // // fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) { return fmt.Errorf(format, c.convertArgs(a)...) } // Fprint is a wrapper for fmt.Fprint that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) { return fmt.Fprint(w, c.convertArgs(a)...) } // Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) { return fmt.Fprintf(w, format, c.convertArgs(a)...) 
} // Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it // passed with a Formatter interface returned by c.NewFormatter. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) { return fmt.Fprintln(w, c.convertArgs(a)...) } // Print is a wrapper for fmt.Print that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Print(c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Print(a ...interface{}) (n int, err error) { return fmt.Print(c.convertArgs(a)...) } // Printf is a wrapper for fmt.Printf that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) { return fmt.Printf(format, c.convertArgs(a)...) } // Println is a wrapper for fmt.Println that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Println(c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Println(a ...interface{}) (n int, err error) { return fmt.Println(c.convertArgs(a)...) 
} // Sprint is a wrapper for fmt.Sprint that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the resulting string. See NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Sprint(a ...interface{}) string { return fmt.Sprint(c.convertArgs(a)...) } // Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were // passed with a Formatter interface returned by c.NewFormatter. It returns // the resulting string. See NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Sprintf(format string, a ...interface{}) string { return fmt.Sprintf(format, c.convertArgs(a)...) } // Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it // were passed with a Formatter interface returned by c.NewFormatter. It // returns the resulting string. See NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b)) func (c *ConfigState) Sprintln(a ...interface{}) string { return fmt.Sprintln(c.convertArgs(a)...) } /* NewFormatter returns a custom formatter that satisfies the fmt.Formatter interface. As a result, it integrates cleanly with standard fmt package printing functions. The formatter is useful for inline printing of smaller data types similar to the standard %v format specifier. The custom formatter only responds to the %v (most compact), %+v (adds pointer addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb combinations. Any other verbs such as %x and %q will be sent to the the standard fmt package for formatting. 
In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).

Typically this function shouldn't be called directly.  It is much easier to make
use of the custom formatter by calling one of the convenience functions such as
c.Printf, c.Println, or c.Fprintf.
*/
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
	return newFormatter(c, v)
}

// Fdump formats and displays the passed arguments to io.Writer w, exactly as
// Dump would.
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
	fdump(c, w, a...)
}

/*
Dump writes the passed parameters to standard out with newlines, customizable
indentation, and extra debug information such as complete types and every
pointer address used to indirect to the final value.  Compared with the fmt
package's built-in printing it offers the following:

	* Pointers are dereferenced and followed
	* Circular data structures are detected and handled properly
	* Custom Stringer/error interfaces are optionally invoked, including
	  on unexported types
	* Custom types which only implement the Stringer/error interfaces via
	  a pointer receiver are optionally invoked when passing non-pointer
	  variables
	* Byte arrays and slices are dumped like the hexdump -C command which
	  includes offsets, byte values in hex, and ASCII output

The configuration options are controlled by modifying the public members
of c.  See ConfigState for options documentation.

See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func (c *ConfigState) Dump(a ...interface{}) {
	out := os.Stdout
	fdump(c, out, a...)
}

// Sdump returns the passed arguments formatted exactly as Dump would print
// them.
func (c *ConfigState) Sdump(a ...interface{}) string {
	buf := new(bytes.Buffer)
	fdump(c, buf, a...)
	return buf.String()
}

// convertArgs returns a new slice, the same length as args, in which every
// argument has been wrapped in a spew Formatter bound to c.
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
	formatters = make([]interface{}, 0, len(args))
	for _, arg := range args {
		formatters = append(formatters, newFormatter(c, arg))
	}
	return formatters
}

// NewDefaultConfig returns a ConfigState with the following default settings.
//
//	Indent: " "
//	MaxDepth: 0
//	DisableMethods: false
//	DisablePointerMethods: false
//	DisablePointerAddresses: false
//	DisableCapacities: false
//	ContinueOnMethod: false
//	SortKeys: false
//	SpewKeys: false
func NewDefaultConfig() *ConfigState {
	cfg := ConfigState{Indent: " "}
	return &cfg
}

================================================
FILE: vendor/github.com/davecgh/go-spew/spew/doc.go
================================================
/*
 * Copyright (c) 2013 Dave Collins
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

/*
Package spew implements a deep pretty printer for Go data structures to aid in
debugging.
A quick overview of the additional features spew provides over the built-in printing facilities for Go data types is as follows: * Pointers are dereferenced and followed * Circular data structures are detected and handled properly * Custom Stringer/error interfaces are optionally invoked, including on unexported types * Custom types which only implement the Stringer/error interfaces via a pointer receiver are optionally invoked when passing non-pointer variables * Byte arrays and slices are dumped like the hexdump -C command which includes offsets, byte values in hex, and ASCII output (only when using Dump style) There are two different approaches spew allows for dumping Go data structures: * Dump style which prints with newlines, customizable indentation, and additional debug information such as types and all pointer addresses used to indirect to the final value * A custom Formatter interface that integrates cleanly with the standard fmt package and replaces %v, %+v, %#v, and %#+v to provide inline printing similar to the default %v while providing the additional functionality outlined above and passing unsupported format verbs such as %x and %q along to fmt Quick Start This section demonstrates how to quickly get started with spew. See the sections below for further details on formatting and configuration options. To dump a variable with full newlines, indentation, type, and pointer information, use Dump, Fdump, or Sdump: spew.Dump(myVar1, myVar2, ...) spew.Fdump(someWriter, myVar1, myVar2, ...) str := spew.Sdump(myVar1, myVar2, ...)
Alternatively, if you would prefer to use format strings with a compacted inline printing style, use the convenience wrappers Printf, Fprintf, etc., with %v (most compact), %+v (adds pointer addresses), %#v (adds types), or %#+v (adds types and pointer addresses): spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2) spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2) spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) Configuration Options Configuration of spew is handled by fields in the ConfigState type. For convenience, all of the top-level functions use a global state available via the spew.Config global. It is also possible to create a ConfigState instance that provides methods equivalent to the top-level functions. This allows several independently configured instances to be used at the same time. See the ConfigState documentation for more details. The following configuration options are available: * Indent String to use for each indentation level for Dump functions. It is a single space by default. A popular alternative is "\t". * MaxDepth Maximum number of levels to descend into nested data structures. There is no limit by default. * DisableMethods Disables invocation of error and Stringer interface methods. Method invocation is enabled by default. * DisablePointerMethods Disables invocation of error and Stringer interface methods on types which only accept pointer receivers from non-pointer variables. Pointer method invocation is enabled by default. * DisablePointerAddresses Disables the printing of pointer addresses. This is useful when diffing data structures in tests. Pointer addresses are printed by default. * DisableCapacities Disables the printing of capacities for arrays, slices, maps and channels. This is useful when diffing data structures in tests. Capacities are printed by default. * ContinueOnMethod Enables recursion into types after invoking error and Stringer interface methods. Recursion after method invocation is disabled by default. * SortKeys Specifies map keys should be sorted before being printed. Use this to have a more deterministic, diffable output.
Note that only native types (bool, int, uint, floats, uintptr and string) and types which implement error or Stringer interfaces are supported with other types sorted according to the reflect.Value.String() output which guarantees display stability. Natural map order is used by default. * SpewKeys Specifies that, as a last resort attempt, map keys should be spewed to strings and sorted by those strings. This is only considered if SortKeys is true. Dump Usage Simply call spew.Dump with a list of variables you want to dump: spew.Dump(myVar1, myVar2, ...) You may also call spew.Fdump if you would prefer to output to an arbitrary io.Writer. For example, to dump to standard error: spew.Fdump(os.Stderr, myVar1, myVar2, ...) A third option is to call spew.Sdump to get the formatted output as a string: str := spew.Sdump(myVar1, myVar2, ...) Sample Dump Output See the Dump example for details on the setup of the types and variables being shown here. (main.Foo) { unexportedField: (*main.Bar)(0xf84002e210)({ flag: (main.Flag) flagTwo, data: (uintptr) }), ExportedField: (map[interface {}]interface {}) (len=1) { (string) (len=3) "one": (bool) true } } Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C command as shown. ([]uint8) (len=32 cap=32) { 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... | 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0| 00000020 31 32 |12| } Custom Formatter Spew provides a custom formatter that implements the fmt.Formatter interface so that it integrates cleanly with standard fmt package printing functions. The formatter is useful for inline printing of smaller data types similar to the standard %v format specifier. The custom formatter only responds to the %v (most compact), %+v (adds pointer addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb combinations. 
Any other verbs such as %x and %q will be sent to the standard fmt package for formatting. In addition, the custom formatter ignores the width and precision arguments (however they will still work on the format specifiers not handled by the custom formatter). Custom Formatter Usage The simplest way to make use of the spew custom formatter is to call one of the convenience functions such as spew.Printf, spew.Println, or spew.Fprintf. The functions have syntax you are most likely already familiar with: spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2) spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) spew.Println(myVar, myVar2) spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2) spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4) See the Index for the full list of convenience functions. Sample Formatter Output Double pointer to a uint8: %v: <**>5 %+v: <**>(0xf8400420d0->0xf8400420c8)5 %#v: (**uint8)5 %#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5 Pointer to circular struct with a uint8 field and a pointer to itself: %v: <*>{1 <*>} %+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)} %#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)} %#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)} See the Printf example for details on the setup of variables being shown here. Errors Since it is possible for custom Stringer/error interfaces to panic, spew detects them and handles them internally by printing the panic information inline with the output. Since spew is intended to provide deep pretty printing capabilities on structures, it intentionally does not return any errors. 
*/
package spew

================================================
FILE: vendor/github.com/davecgh/go-spew/spew/dump.go
================================================
/*
 * Copyright (c) 2013 Dave Collins
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

package spew

import (
	"bytes"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	"reflect"
	"regexp"
	"strconv"
	"strings"
)

var (
	// uint8Type is a reflect.Type representing a uint8. It is used to
	// convert cgo types to uint8 slices for hexdumping.
	uint8Type = reflect.TypeOf(uint8(0))

	// cCharRE is a regular expression that matches a cgo char.
	// It is used to detect character arrays to hexdump them.
	cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")

	// cUnsignedCharRE is a regular expression that matches a cgo unsigned
	// char. It is used to detect unsigned character arrays to hexdump
	// them.
	cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")

	// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
	// It is used to detect uint8_t arrays to hexdump them.
	cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
)

// dumpState contains information about the state of a dump operation.
type dumpState struct {
	// w receives every byte of output produced by the dump.
	w io.Writer
	// depth is the current nesting level; indent() repeats cs.Indent this
	// many times and it is compared against cs.MaxDepth.
	depth int
	// pointers maps each followed pointer address to the depth at which it
	// was recorded; dumpPtr uses it to detect circular references.
	pointers map[uintptr]int
	// ignoreNextType is set by dumpPtr (which prints its own type header)
	// so the following dump call skips printing the "(type) " prefix.
	ignoreNextType bool
	// ignoreNextIndent is set before dumping map values and struct fields
	// so the value starts on the same line as its key or field name.
	ignoreNextIndent bool
	// cs supplies the configuration options (Indent, MaxDepth, ...).
	cs *ConfigState
}

// indent performs indentation according to the depth level and cs.Indent
// option. A pending ignoreNextIndent request suppresses exactly one call.
func (d *dumpState) indent() {
	if d.ignoreNextIndent {
		d.ignoreNextIndent = false
		return
	}
	d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
}

// unpackValue returns values inside of non-nil interfaces when possible.
// This is useful for data types like structs, arrays, slices, and maps which
// can contain varying types packed inside an interface.
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
	if v.Kind() == reflect.Interface && !v.IsNil() {
		v = v.Elem()
	}
	return v
}

// dumpPtr handles formatting of pointers by indirecting them as necessary.
// It writes "(" + one '*' per followed indirection + the final type + ")",
// then the pointer chain addresses (unless DisablePointerAddresses is set),
// then the dereferenced value, a nil marker, or a circular marker in parens.
func (d *dumpState) dumpPtr(v reflect.Value) {
	// Remove pointers at or below the current depth from map used to detect
	// circular refs.
	for k, depth := range d.pointers {
		if depth >= d.depth {
			delete(d.pointers, k)
		}
	}

	// Keep list of all dereferenced pointers to show later.
	pointerChain := make([]uintptr, 0)

	// Figure out how many levels of indirection there are by dereferencing
	// pointers and unpacking interfaces down the chain while detecting circular
	// references.
	nilFound := false
	cycleFound := false
	indirects := 0
	ve := v
	for ve.Kind() == reflect.Ptr {
		if ve.IsNil() {
			nilFound = true
			break
		}
		indirects++
		addr := ve.Pointer()
		pointerChain = append(pointerChain, addr)
		// An address already recorded at a shallower depth belongs to an
		// ancestor of the current value, so following it again would loop.
		// The cyclic pointer is not dereferenced, so its increment of
		// indirects is undone (its address still stays in pointerChain).
		if pd, ok := d.pointers[addr]; ok && pd < d.depth {
			cycleFound = true
			indirects--
			break
		}
		d.pointers[addr] = d.depth

		ve = ve.Elem()
		if ve.Kind() == reflect.Interface {
			if ve.IsNil() {
				nilFound = true
				break
			}
			ve = ve.Elem()
		}
	}

	// Display type information.
	d.w.Write(openParenBytes)
	d.w.Write(bytes.Repeat(asteriskBytes, indirects))
	d.w.Write([]byte(ve.Type().String()))
	d.w.Write(closeParenBytes)

	// Display pointer information.
	if !d.cs.DisablePointerAddresses && len(pointerChain) > 0 {
		d.w.Write(openParenBytes)
		for i, addr := range pointerChain {
			if i > 0 {
				d.w.Write(pointerChainBytes)
			}
			printHexPtr(d.w, addr)
		}
		d.w.Write(closeParenBytes)
	}

	// Display dereferenced value.
	d.w.Write(openParenBytes)
	switch {
	case nilFound == true:
		d.w.Write(nilAngleBytes)

	case cycleFound == true:
		d.w.Write(circularBytes)

	default:
		// The type header was already written above, so the nested dump
		// must not print it again.
		d.ignoreNextType = true
		d.dump(ve)
	}
	d.w.Write(closeParenBytes)
}

// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
// reflection) arrays and slices are dumped in hexdump -C fashion.
func (d *dumpState) dumpSlice(v reflect.Value) {
	// Determine whether this type should be hex dumped or not. Also,
	// for types which should be hexdumped, try to use the underlying data
	// first, then fall back to trying to convert them to a uint8 slice.
	var buf []uint8
	doConvert := false
	doHexDump := false
	numEntries := v.Len()
	if numEntries > 0 {
		vt := v.Index(0).Type()
		vts := vt.String()
		switch {
		// C types that need to be converted.
		case cCharRE.MatchString(vts):
			fallthrough
		case cUnsignedCharRE.MatchString(vts):
			fallthrough
		case cUint8tCharRE.MatchString(vts):
			doConvert = true

		// Try to use existing uint8 slices and fall back to converting
		// and copying if that fails.
		case vt.Kind() == reflect.Uint8:
			// We need an addressable interface to convert the type
			// to a byte slice. However, the reflect package won't
			// give us an interface on certain things like
			// unexported struct fields in order to enforce
			// visibility rules. We use unsafe, when available, to
			// bypass these restrictions since this package does not
			// mutate the values.
			vs := v
			if !vs.CanInterface() || !vs.CanAddr() {
				vs = unsafeReflectValue(vs)
			}
			if !UnsafeDisabled {
				vs = vs.Slice(0, numEntries)

				// Use the existing uint8 slice if it can be
				// type asserted.
				iface := vs.Interface()
				if slice, ok := iface.([]uint8); ok {
					buf = slice
					doHexDump = true
					// Leaves the switch, skipping doConvert below.
					break
				}
			}

			// The underlying data needs to be converted if it can't
			// be type asserted to a uint8 slice.
			doConvert = true
		}

		// Copy and convert the underlying type if needed.
		if doConvert && vt.ConvertibleTo(uint8Type) {
			// Convert and copy each element into a uint8 byte
			// slice.
			buf = make([]uint8, numEntries)
			for i := 0; i < numEntries; i++ {
				vv := v.Index(i)
				buf[i] = uint8(vv.Convert(uint8Type).Uint())
			}
			doHexDump = true
		}
	}

	// Hexdump the entire slice as needed.
	if doHexDump {
		indent := strings.Repeat(d.cs.Indent, d.depth)
		str := indent + hex.Dump(buf)
		// Indent every line of the dump. hex.Dump output ends with a
		// newline, so the Replace also appends one indent after it; the
		// TrimRight strips those trailing indent characters again.
		str = strings.Replace(str, "\n", "\n"+indent, -1)
		str = strings.TrimRight(str, d.cs.Indent)
		d.w.Write([]byte(str))
		return
	}

	// Recursively call dump for each item.
	for i := 0; i < numEntries; i++ {
		d.dump(d.unpackValue(v.Index(i)))
		if i < (numEntries - 1) {
			d.w.Write(commaNewlineBytes)
		} else {
			d.w.Write(newlineBytes)
		}
	}
}

// dump is the main workhorse for dumping a value. It uses the passed reflect
// value to figure out what kind of object we are dealing with and formats it
// appropriately. It is a recursive function, however circular data structures
// are detected and handled properly.
func (d *dumpState) dump(v reflect.Value) {
	// Handle invalid reflect values immediately.
	kind := v.Kind()
	if kind == reflect.Invalid {
		d.w.Write(invalidAngleBytes)
		return
	}

	// Handle pointers specially.
	if kind == reflect.Ptr {
		d.indent()
		d.dumpPtr(v)
		return
	}

	// Print type information unless already handled elsewhere.
	if !d.ignoreNextType {
		d.indent()
		d.w.Write(openParenBytes)
		d.w.Write([]byte(v.Type().String()))
		d.w.Write(closeParenBytes)
		d.w.Write(spaceBytes)
	}
	d.ignoreNextType = false

	// Display length and capacity if the built-in len and cap functions
	// work with the value's kind and the len/cap itself is non-zero.
	valueLen, valueCap := 0, 0
	switch v.Kind() {
	case reflect.Array, reflect.Slice, reflect.Chan:
		valueLen, valueCap = v.Len(), v.Cap()
	case reflect.Map, reflect.String:
		valueLen = v.Len()
	}
	if valueLen != 0 || !d.cs.DisableCapacities && valueCap != 0 {
		d.w.Write(openParenBytes)
		if valueLen != 0 {
			d.w.Write(lenEqualsBytes)
			printInt(d.w, int64(valueLen), 10)
		}
		if !d.cs.DisableCapacities && valueCap != 0 {
			if valueLen != 0 {
				d.w.Write(spaceBytes)
			}
			d.w.Write(capEqualsBytes)
			printInt(d.w, int64(valueCap), 10)
		}
		d.w.Write(closeParenBytes)
		d.w.Write(spaceBytes)
	}

	// Call Stringer/error interfaces if they exist and the handle methods flag
	// is enabled
	if !d.cs.DisableMethods {
		if (kind != reflect.Invalid) && (kind != reflect.Interface) {
			if handled := handleMethods(d.cs, d.w, v); handled {
				return
			}
		}
	}

	switch kind {
	case reflect.Invalid:
		// Do nothing. We should never get here since invalid has already
		// been handled above.

	case reflect.Bool:
		printBool(d.w, v.Bool())

	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
		printInt(d.w, v.Int(), 10)

	case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
		printUint(d.w, v.Uint(), 10)

	case reflect.Float32:
		printFloat(d.w, v.Float(), 32)

	case reflect.Float64:
		printFloat(d.w, v.Float(), 64)

	case reflect.Complex64:
		printComplex(d.w, v.Complex(), 32)

	case reflect.Complex128:
		printComplex(d.w, v.Complex(), 64)

	case reflect.Slice:
		if v.IsNil() {
			d.w.Write(nilAngleBytes)
			break
		}
		fallthrough

	case reflect.Array:
		d.w.Write(openBraceNewlineBytes)
		d.depth++
		if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
			d.indent()
			d.w.Write(maxNewlineBytes)
		} else {
			d.dumpSlice(v)
		}
		d.depth--
		d.indent()
		d.w.Write(closeBraceBytes)

	case reflect.String:
		d.w.Write([]byte(strconv.Quote(v.String())))

	case reflect.Interface:
		// The only time we should get here is for nil interfaces due to
		// unpackValue calls.
		if v.IsNil() {
			d.w.Write(nilAngleBytes)
		}

	case reflect.Ptr:
		// Do nothing. We should never get here since pointers have already
		// been handled above.

	case reflect.Map:
		// nil maps should be indicated as different than empty maps
		if v.IsNil() {
			d.w.Write(nilAngleBytes)
			break
		}

		d.w.Write(openBraceNewlineBytes)
		d.depth++
		if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
			d.indent()
			d.w.Write(maxNewlineBytes)
		} else {
			numEntries := v.Len()
			keys := v.MapKeys()
			// Map iteration order is random; sorting gives stable output.
			if d.cs.SortKeys {
				sortValues(keys, d.cs)
			}
			for i, key := range keys {
				d.dump(d.unpackValue(key))
				d.w.Write(colonSpaceBytes)
				d.ignoreNextIndent = true
				d.dump(d.unpackValue(v.MapIndex(key)))
				if i < (numEntries - 1) {
					d.w.Write(commaNewlineBytes)
				} else {
					d.w.Write(newlineBytes)
				}
			}
		}
		d.depth--
		d.indent()
		d.w.Write(closeBraceBytes)

	case reflect.Struct:
		d.w.Write(openBraceNewlineBytes)
		d.depth++
		if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
			d.indent()
			d.w.Write(maxNewlineBytes)
		} else {
			vt := v.Type()
			numFields := v.NumField()
			for i := 0; i < numFields; i++ {
				d.indent()
				vtf := vt.Field(i)
				d.w.Write([]byte(vtf.Name))
				d.w.Write(colonSpaceBytes)
				d.ignoreNextIndent = true
				d.dump(d.unpackValue(v.Field(i)))
				if i < (numFields - 1) {
					d.w.Write(commaNewlineBytes)
				} else {
					d.w.Write(newlineBytes)
				}
			}
		}
		d.depth--
		d.indent()
		d.w.Write(closeBraceBytes)

	case reflect.Uintptr:
		printHexPtr(d.w, uintptr(v.Uint()))

	case reflect.UnsafePointer, reflect.Chan, reflect.Func:
		printHexPtr(d.w, v.Pointer())

	// There were not any other types at the time this code was written, but
	// fall back to letting the default fmt package handle it in case any new
	// types are added.
	default:
		if v.CanInterface() {
			fmt.Fprintf(d.w, "%v", v.Interface())
		} else {
			fmt.Fprintf(d.w, "%v", v.String())
		}
	}
}

// fdump is a helper function to consolidate the logic from the various public
// methods which take varying writers and config states.
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
	for _, arg := range a {
		// A nil interface has no reflect type, so it is written directly.
		if arg == nil {
			w.Write(interfaceBytes)
			w.Write(spaceBytes)
			w.Write(nilAngleBytes)
			w.Write(newlineBytes)
			continue
		}

		// Each argument gets a fresh state and pointer map, so depth and
		// cycle tracking never carry over between arguments.
		d := dumpState{w: w, cs: cs}
		d.pointers = make(map[uintptr]int)
		d.dump(reflect.ValueOf(arg))
		d.w.Write(newlineBytes)
	}
}

// Fdump formats and displays the passed arguments to io.Writer w. It formats
// exactly the same as Dump.
func Fdump(w io.Writer, a ...interface{}) {
	fdump(&Config, w, a...)
}

// Sdump returns a string with the passed arguments formatted exactly the same
// as Dump.
func Sdump(a ...interface{}) string {
	var buf bytes.Buffer
	fdump(&Config, &buf, a...)
	return buf.String()
}

/*
Dump displays the passed parameters to standard out with newlines, customizable
indentation, and additional debug information such as complete types and all
pointer addresses used to indirect to the final value. It provides the
following features over the built-in printing facilities provided by the fmt
package:

	* Pointers are dereferenced and followed
	* Circular data structures are detected and handled properly
	* Custom Stringer/error interfaces are optionally invoked, including
	  on unexported types
	* Custom types which only implement the Stringer/error interfaces via
	  a pointer receiver are optionally invoked when passing non-pointer
	  variables
	* Byte arrays and slices are dumped like the hexdump -C command which
	  includes offsets, byte values in hex, and ASCII output

The configuration options are controlled by an exported package global,
spew.Config. See ConfigState for options documentation.

See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func Dump(a ...interface{}) {
	fdump(&Config, os.Stdout, a...)
}

================================================
FILE: vendor/github.com/davecgh/go-spew/spew/format.go
================================================
/*
 * Copyright (c) 2013 Dave Collins
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

package spew

import (
	"bytes"
	"fmt"
	"reflect"
	"strconv"
	"strings"
)

// supportedFlags is a list of all the character flags supported by fmt package.
const supportedFlags = "0-+# "

// formatState implements the fmt.Formatter interface and contains information
// about the state of a formatting operation. The NewFormatter function can
// be used to get a new Formatter which can be used directly as arguments
// in standard fmt package printing calls.
type formatState struct {
	// value is the argument wrapped by this formatter.
	value interface{}
	// fs is the fmt.State of the current Format call; output goes here.
	fs fmt.State
	// depth is the current nesting level, compared against cs.MaxDepth.
	depth int
	// pointers maps followed pointer addresses to the depth they were
	// recorded at, for circular reference detection.
	pointers map[uintptr]int
	// ignoreNextType suppresses the next "(type)" header when %#v is used.
	ignoreNextType bool
	// cs supplies the configuration options.
	cs *ConfigState
}

// buildDefaultFormat recreates the original format string without precision
// and width information to pass in to fmt.Sprintf in the case of an
// unrecognized type. Unless new types are added to the language, this
// function won't ever be called.
func (f *formatState) buildDefaultFormat() (format string) { buf := bytes.NewBuffer(percentBytes) for _, flag := range supportedFlags { if f.fs.Flag(int(flag)) { buf.WriteRune(flag) } } buf.WriteRune('v') format = buf.String() return format } // constructOrigFormat recreates the original format string including precision // and width information to pass along to the standard fmt package. This allows // automatic deferral of all format strings this package doesn't support. func (f *formatState) constructOrigFormat(verb rune) (format string) { buf := bytes.NewBuffer(percentBytes) for _, flag := range supportedFlags { if f.fs.Flag(int(flag)) { buf.WriteRune(flag) } } if width, ok := f.fs.Width(); ok { buf.WriteString(strconv.Itoa(width)) } if precision, ok := f.fs.Precision(); ok { buf.Write(precisionBytes) buf.WriteString(strconv.Itoa(precision)) } buf.WriteRune(verb) format = buf.String() return format } // unpackValue returns values inside of non-nil interfaces when possible and // ensures that types for values which have been unpacked from an interface // are displayed when the show types flag is also set. // This is useful for data types like structs, arrays, slices, and maps which // can contain varying types packed inside an interface. func (f *formatState) unpackValue(v reflect.Value) reflect.Value { if v.Kind() == reflect.Interface { f.ignoreNextType = false if !v.IsNil() { v = v.Elem() } } return v } // formatPtr handles formatting of pointers by indirecting them as necessary. func (f *formatState) formatPtr(v reflect.Value) { // Display nil if top level pointer is nil. showTypes := f.fs.Flag('#') if v.IsNil() && (!showTypes || f.ignoreNextType) { f.fs.Write(nilAngleBytes) return } // Remove pointers at or below the current depth from map used to detect // circular refs. for k, depth := range f.pointers { if depth >= f.depth { delete(f.pointers, k) } } // Keep list of all dereferenced pointers to possibly show later. 
pointerChain := make([]uintptr, 0) // Figure out how many levels of indirection there are by derferencing // pointers and unpacking interfaces down the chain while detecting circular // references. nilFound := false cycleFound := false indirects := 0 ve := v for ve.Kind() == reflect.Ptr { if ve.IsNil() { nilFound = true break } indirects++ addr := ve.Pointer() pointerChain = append(pointerChain, addr) if pd, ok := f.pointers[addr]; ok && pd < f.depth { cycleFound = true indirects-- break } f.pointers[addr] = f.depth ve = ve.Elem() if ve.Kind() == reflect.Interface { if ve.IsNil() { nilFound = true break } ve = ve.Elem() } } // Display type or indirection level depending on flags. if showTypes && !f.ignoreNextType { f.fs.Write(openParenBytes) f.fs.Write(bytes.Repeat(asteriskBytes, indirects)) f.fs.Write([]byte(ve.Type().String())) f.fs.Write(closeParenBytes) } else { if nilFound || cycleFound { indirects += strings.Count(ve.Type().String(), "*") } f.fs.Write(openAngleBytes) f.fs.Write([]byte(strings.Repeat("*", indirects))) f.fs.Write(closeAngleBytes) } // Display pointer information depending on flags. if f.fs.Flag('+') && (len(pointerChain) > 0) { f.fs.Write(openParenBytes) for i, addr := range pointerChain { if i > 0 { f.fs.Write(pointerChainBytes) } printHexPtr(f.fs, addr) } f.fs.Write(closeParenBytes) } // Display dereferenced value. switch { case nilFound == true: f.fs.Write(nilAngleBytes) case cycleFound == true: f.fs.Write(circularShortBytes) default: f.ignoreNextType = true f.format(ve) } } // format is the main workhorse for providing the Formatter interface. It // uses the passed reflect value to figure out what kind of object we are // dealing with and formats it appropriately. It is a recursive function, // however circular data structures are detected and handled properly. func (f *formatState) format(v reflect.Value) { // Handle invalid reflect values immediately. 
kind := v.Kind() if kind == reflect.Invalid { f.fs.Write(invalidAngleBytes) return } // Handle pointers specially. if kind == reflect.Ptr { f.formatPtr(v) return } // Print type information unless already handled elsewhere. if !f.ignoreNextType && f.fs.Flag('#') { f.fs.Write(openParenBytes) f.fs.Write([]byte(v.Type().String())) f.fs.Write(closeParenBytes) } f.ignoreNextType = false // Call Stringer/error interfaces if they exist and the handle methods // flag is enabled. if !f.cs.DisableMethods { if (kind != reflect.Invalid) && (kind != reflect.Interface) { if handled := handleMethods(f.cs, f.fs, v); handled { return } } } switch kind { case reflect.Invalid: // Do nothing. We should never get here since invalid has already // been handled above. case reflect.Bool: printBool(f.fs, v.Bool()) case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: printInt(f.fs, v.Int(), 10) case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: printUint(f.fs, v.Uint(), 10) case reflect.Float32: printFloat(f.fs, v.Float(), 32) case reflect.Float64: printFloat(f.fs, v.Float(), 64) case reflect.Complex64: printComplex(f.fs, v.Complex(), 32) case reflect.Complex128: printComplex(f.fs, v.Complex(), 64) case reflect.Slice: if v.IsNil() { f.fs.Write(nilAngleBytes) break } fallthrough case reflect.Array: f.fs.Write(openBracketBytes) f.depth++ if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) { f.fs.Write(maxShortBytes) } else { numEntries := v.Len() for i := 0; i < numEntries; i++ { if i > 0 { f.fs.Write(spaceBytes) } f.ignoreNextType = true f.format(f.unpackValue(v.Index(i))) } } f.depth-- f.fs.Write(closeBracketBytes) case reflect.String: f.fs.Write([]byte(v.String())) case reflect.Interface: // The only time we should get here is for nil interfaces due to // unpackValue calls. if v.IsNil() { f.fs.Write(nilAngleBytes) } case reflect.Ptr: // Do nothing. We should never get here since pointers have already // been handled above. 
case reflect.Map: // nil maps should be indicated as different than empty maps if v.IsNil() { f.fs.Write(nilAngleBytes) break } f.fs.Write(openMapBytes) f.depth++ if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) { f.fs.Write(maxShortBytes) } else { keys := v.MapKeys() if f.cs.SortKeys { sortValues(keys, f.cs) } for i, key := range keys { if i > 0 { f.fs.Write(spaceBytes) } f.ignoreNextType = true f.format(f.unpackValue(key)) f.fs.Write(colonBytes) f.ignoreNextType = true f.format(f.unpackValue(v.MapIndex(key))) } } f.depth-- f.fs.Write(closeMapBytes) case reflect.Struct: numFields := v.NumField() f.fs.Write(openBraceBytes) f.depth++ if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) { f.fs.Write(maxShortBytes) } else { vt := v.Type() for i := 0; i < numFields; i++ { if i > 0 { f.fs.Write(spaceBytes) } vtf := vt.Field(i) if f.fs.Flag('+') || f.fs.Flag('#') { f.fs.Write([]byte(vtf.Name)) f.fs.Write(colonBytes) } f.format(f.unpackValue(v.Field(i))) } } f.depth-- f.fs.Write(closeBraceBytes) case reflect.Uintptr: printHexPtr(f.fs, uintptr(v.Uint())) case reflect.UnsafePointer, reflect.Chan, reflect.Func: printHexPtr(f.fs, v.Pointer()) // There were not any other types at the time this code was written, but // fall back to letting the default fmt package handle it if any get added. default: format := f.buildDefaultFormat() if v.CanInterface() { fmt.Fprintf(f.fs, format, v.Interface()) } else { fmt.Fprintf(f.fs, format, v.String()) } } } // Format satisfies the fmt.Formatter interface. See NewFormatter for usage // details. func (f *formatState) Format(fs fmt.State, verb rune) { f.fs = fs // Use standard formatting for verbs that are not v. 
if verb != 'v' { format := f.constructOrigFormat(verb) fmt.Fprintf(fs, format, f.value) return } if f.value == nil { if fs.Flag('#') { fs.Write(interfaceBytes) } fs.Write(nilAngleBytes) return } f.format(reflect.ValueOf(f.value)) } // newFormatter is a helper function to consolidate the logic from the various // public methods which take varying config states. func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter { fs := &formatState{value: v, cs: cs} fs.pointers = make(map[uintptr]int) return fs } /* NewFormatter returns a custom formatter that satisfies the fmt.Formatter interface. As a result, it integrates cleanly with standard fmt package printing functions. The formatter is useful for inline printing of smaller data types similar to the standard %v format specifier. The custom formatter only responds to the %v (most compact), %+v (adds pointer addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb combinations. Any other verbs such as %x and %q will be sent to the the standard fmt package for formatting. In addition, the custom formatter ignores the width and precision arguments (however they will still work on the format specifiers not handled by the custom formatter). Typically this function shouldn't be called directly. It is much easier to make use of the custom formatter by calling one of the convenience functions such as Printf, Println, or Fprintf. */ func NewFormatter(v interface{}) fmt.Formatter { return newFormatter(&Config, v) } ================================================ FILE: vendor/github.com/davecgh/go-spew/spew/spew.go ================================================ /* * Copyright (c) 2013 Dave Collins * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. 
* * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ package spew import ( "fmt" "io" ) // Errorf is a wrapper for fmt.Errorf that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the formatted string as a value that satisfies error. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b)) func Errorf(format string, a ...interface{}) (err error) { return fmt.Errorf(format, convertArgs(a)...) } // Fprint is a wrapper for fmt.Fprint that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b)) func Fprint(w io.Writer, a ...interface{}) (n int, err error) { return fmt.Fprint(w, convertArgs(a)...) } // Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the number of bytes written and any write error encountered. See // NewFormatter for formatting details. 
// // This function is shorthand for the following syntax: // // fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b)) func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) { return fmt.Fprintf(w, format, convertArgs(a)...) } // Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it // passed with a default Formatter interface returned by NewFormatter. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b)) func Fprintln(w io.Writer, a ...interface{}) (n int, err error) { return fmt.Fprintln(w, convertArgs(a)...) } // Print is a wrapper for fmt.Print that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b)) func Print(a ...interface{}) (n int, err error) { return fmt.Print(convertArgs(a)...) } // Printf is a wrapper for fmt.Printf that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the number of bytes written and any write error encountered. See // NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b)) func Printf(format string, a ...interface{}) (n int, err error) { return fmt.Printf(format, convertArgs(a)...) } // Println is a wrapper for fmt.Println that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the number of bytes written and any write error encountered. See // NewFormatter for formatting details. 
// // This function is shorthand for the following syntax: // // fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b)) func Println(a ...interface{}) (n int, err error) { return fmt.Println(convertArgs(a)...) } // Sprint is a wrapper for fmt.Sprint that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the resulting string. See NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b)) func Sprint(a ...interface{}) string { return fmt.Sprint(convertArgs(a)...) } // Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were // passed with a default Formatter interface returned by NewFormatter. It // returns the resulting string. See NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b)) func Sprintf(format string, a ...interface{}) string { return fmt.Sprintf(format, convertArgs(a)...) } // Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it // were passed with a default Formatter interface returned by NewFormatter. It // returns the resulting string. See NewFormatter for formatting details. // // This function is shorthand for the following syntax: // // fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b)) func Sprintln(a ...interface{}) string { return fmt.Sprintln(convertArgs(a)...) } // convertArgs accepts a slice of arguments and returns a slice of the same // length with each argument converted to a default spew Formatter interface. 
func convertArgs(args []interface{}) (formatters []interface{}) { formatters = make([]interface{}, len(args)) for index, arg := range args { formatters[index] = NewFormatter(arg) } return formatters } ================================================ FILE: vendor/github.com/go-sql-driver/mysql/AUTHORS ================================================ # This is the official list of Go-MySQL-Driver authors for copyright purposes. # If you are submitting a patch, please add your name or the name of the # organization which holds the copyright to this list in alphabetical order. # Names should be added to this file as # Name # The email address is not required for organizations. # Please keep the list sorted. # Individual Persons Aaron Hopkins Achille Roussel Alexey Palazhchenko Andrew Reid Arne Hormann Asta Xie Bulat Gaifullin Carlos Nieto Chris Moos Daniel Montoya Daniel Nichter Daniël van Eeden Dave Protasowski DisposaBoy Egor Smolyakov Evan Shaw Frederick Mayle Gustavo Kristic Hanno Braun Henri Yandell Hirotaka Yamamoto ICHINOSE Shogo INADA Naoki Jacek Szwec James Harr Jeff Hodges Jeffrey Charles Jian Zhen Joshua Prunier Julien Lefevre Julien Schmidt Justin Li Justin Nuß Kamil Dziedzic Kevin Malachowski Kieron Woodhouse Lennart Rudolph Leonardo YongUk Kim Linh Tran Tuan Lion Yang Luca Looz Lucas Liu Luke Scott Maciej Zimnoch Michael Woolnough Nicola Peduzzi Olivier Mengué oscarzhao Paul Bonser Peter Schultz Rebecca Chin Reed Allman Robert Russell Runrioter Wung Shuode Li Soroush Pour Stan Putrya Stanley Gunawan Xiangyu Hu Xiaobing Jiang Xiuming Chen Zhenye Xie # Organizations Barracuda Networks, Inc. Counting Ltd. Google Inc. InfoSum Ltd. Keybase Inc. Percona LLC Pivotal Inc. Stripe Inc. 
================================================ FILE: vendor/github.com/go-sql-driver/mysql/CHANGELOG.md ================================================ ## Version 1.3 (2016-12-01) Changes: - Go 1.1 is no longer supported - Use decimal fields in MySQL to format time types (#249) - Buffer optimizations (#269) - TLS ServerName defaults to the host (#283) - Refactoring (#400, #410, #437) - Adjusted documentation for second generation CloudSQL (#485) - Documented DSN system var quoting rules (#502) - Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512) New Features: - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) - Support for returning table alias on Columns() (#289, #359, #382) - Placeholder interpolation, can be activated with the DSN parameter `interpolateParams=true` (#309, #318, #490) - Support for uint64 parameters with high bit set (#332, #345) - Cleartext authentication plugin support (#327) - Exported ParseDSN function and the Config struct (#403, #419, #429) - Read / Write timeouts (#401) - Support for JSON field type (#414) - Support for multi-statements and multi-results (#411, #431) - DSN parameter to set the driver-side max_allowed_packet value manually (#489) - Native password authentication plugin support (#494, #524) Bugfixes: - Fixed handling of queries without columns and rows (#255) - Fixed a panic when SetKeepAlive() failed (#298) - Handle ERR packets while reading rows (#321) - Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349) - Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356) - Actually zero out bytes in handshake response (#378) - Fixed race condition in registering LOAD DATA INFILE handler (#383) - Fixed tests with MySQL 5.7.9+ (#380) - QueryUnescape TLS config names (#397) - Fixed "broken pipe" error by writing to closed socket (#390) - Fixed LOAD LOCAL DATA INFILE buffering (#424) - Fixed parsing of floats into float64 when placeholders are used (#434) - Fixed DSN tests
with Go 1.7+ (#459) - Handle ERR packets while waiting for EOF (#473) - Invalidate connection on error while discarding additional results (#513) - Allow terminating packets of length 0 (#516) ## Version 1.2 (2014-06-03) Changes: - We switched back to a "rolling release". `go get` installs the current master branch again - Version v1 of the driver will not be maintained anymore. Go 1.0 is no longer supported by this driver - Exported errors to allow easy checking from application code - Enabled TCP Keepalives on TCP connections - Optimized INFILE handling (better buffer size calculation, lazy init, ...) - The DSN parser also checks for a missing separating slash - Faster binary date / datetime to string formatting - Also exported the MySQLWarning type - mysqlConn.Close returns the first error encountered instead of ignoring all errors - writePacket() automatically writes the packet size to the header - readPacket() uses an iterative approach instead of the recursive approach to merge split packets New Features: - `RegisterDial` allows the usage of a custom dial function to establish the network connection - Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter - Logging of critical errors is configurable with `SetLogger` - Google CloudSQL support Bugfixes: - Allow more than 32 parameters in prepared statements - Various old_password fixes - Fixed TestConcurrent test to pass Go's race detection - Fixed appendLengthEncodedInteger for large numbers - Renamed readLengthEnodedString to readLengthEncodedString and skipLengthEnodedString to skipLengthEncodedString (fixed typo) ## Version 1.1 (2013-11-02) Changes: - Go-MySQL-Driver now requires Go 1.1 - Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore - Made closing rows and connections error tolerant.
This allows for example deferring rows.Close() without checking for errors - `[]byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")` - DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'. - Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries - Optimized the buffer for reading - stmt.Query now caches column metadata - New Logo - Changed the copyright header to include all contributors - Improved the LOAD INFILE documentation - The driver struct is now exported to make the driver directly accessible - Refactored the driver tests - Added more benchmarks and moved all to a separate file - Other small refactoring New Features: - Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure - Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs - Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. 
Custom TLS configs can be registered and used Bugfixes: - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification - Convert to DB timezone when inserting `time.Time` - Split packets (more than 16MB) are now merged correctly - Fixed false positive `io.EOF` errors when the data was fully read - Avoid panics on reuse of closed connections - Fixed empty string producing false nil values - Fixed sign byte for positive TIME fields ## Version 1.0 (2013-05-14) Initial Release ================================================ FILE: vendor/github.com/go-sql-driver/mysql/CONTRIBUTING.md ================================================ # Contributing Guidelines ## Reporting Issues Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed). ## Contributing Code By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file. Don't forget to add yourself to the AUTHORS file. ### Code Review Everyone is invited to review and comment on pull requests. If it looks fine to you, comment with "LGTM" (Looks good to me). If changes are required, notify the reviewers with "PTAL" (Please take another look) after committing the fixes. Before merging the Pull Request, at least one [team member](https://github.com/go-sql-driver?tab=members) must have commented with "LGTM". ## Development Ideas If you are looking for ideas for code contributions, please check our [Development Ideas](https://github.com/go-sql-driver/mysql/wiki/Development-Ideas) Wiki page. ================================================ FILE: vendor/github.com/go-sql-driver/mysql/LICENSE ================================================ Mozilla Public License Version 2.0 ================================== 1.
Definitions -------------- 1.1. "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 1.2. "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution. 1.3. "Contribution" means Covered Software of a particular Contributor. 1.4. "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 1.5. "Incompatible With Secondary Licenses" means (a) that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or (b) that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. 1.6. "Executable Form" means any form of the work other than Source Code Form. 1.7. "Larger Work" means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" means this document. 1.9. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 1.10. "Modifications" means any of the following: (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or (b) any new file in Source Code Form that contains any Covered Software. 1.11. 
"Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 1.12. "Secondary License" means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. 1.13. "Source Code Form" means the form of the work preferred for making modifications. 1.14. "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 2. License Grants and Conditions -------------------------------- 2.1. Grants Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and (b) under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. 2.2. 
Effective Date The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 2.3. Limitations on Grant Scope The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: (a) for any code that a Contributor has removed from Covered Software; or (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or (c) under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). 2.4. Subsequent Licenses No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). 2.5. Representation Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 2.6. Fair Use This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 2.7. Conditions Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. 3. Responsibilities ------------------- 3.1. 
Distribution of Source Form All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form. 3.2. Distribution of Executable Form If You distribute Covered Software in Executable Form then: (a) such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and (b) You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. 3.3. Distribution of a Larger Work You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). 3.4. 
Notices You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 3.5. Application of Additional Terms You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. 4. Inability to Comply Due to Statute or Regulation --------------------------------------------------- If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 5. Termination -------------- 5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. 
However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. 5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. ************************************************************************ * * * 6. Disclaimer of Warranty * * ------------------------- * * * * Covered Software is provided under this License on an "as is" * * basis, without warranty of any kind, either expressed, implied, or * * statutory, including, without limitation, warranties that the * * Covered Software is free of defects, merchantable, fit for a * * particular purpose or non-infringing. The entire risk as to the * * quality and performance of the Covered Software is with You. 
* * Should any Covered Software prove defective in any respect, You * * (not any Contributor) assume the cost of any necessary servicing, * * repair, or correction. This disclaimer of warranty constitutes an * * essential part of this License. No use of any Covered Software is * * authorized under this License except under this disclaimer. * * * ************************************************************************ ************************************************************************ * * * 7. Limitation of Liability * * -------------------------- * * * * Under no circumstances and under no legal theory, whether tort * * (including negligence), contract, or otherwise, shall any * * Contributor, or anyone who distributes Covered Software as * * permitted above, be liable to You for any direct, indirect, * * special, incidental, or consequential damages of any character * * including, without limitation, damages for lost profits, loss of * * goodwill, work stoppage, computer failure or malfunction, or any * * and all other commercial damages or losses, even if such party * * shall have been informed of the possibility of such damages. This * * limitation of liability shall not apply to liability for death or * * personal injury resulting from such party's negligence to the * * extent applicable law prohibits such limitation. Some * * jurisdictions do not allow the exclusion or limitation of * * incidental or consequential damages, so this exclusion and * * limitation may not apply to You. * * * ************************************************************************ 8. Litigation ------------- Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. 9. 
Miscellaneous ---------------- This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License --------------------------- 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. Exhibit A - Source Code Form License Notice ------------------------------------------- This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - "Incompatible With Secondary Licenses" Notice --------------------------------------------------------- This Source Code Form is "Incompatible With Secondary Licenses", as defined by the Mozilla Public License, v. 2.0. ================================================ FILE: vendor/github.com/go-sql-driver/mysql/README.md ================================================ # Go-MySQL-Driver A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package ![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin") --------------------------------------- * [Features](#features) * [Requirements](#requirements) * [Installation](#installation) * [Usage](#usage) * [DSN (Data Source Name)](#dsn-data-source-name) * [Password](#password) * [Protocol](#protocol) * [Address](#address) * [Parameters](#parameters) * [Examples](#examples) * [Connection pool and timeouts](#connection-pool-and-timeouts) * [context.Context Support](#contextcontext-support) * [ColumnType Support](#columntype-support) * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) * [time.Time support](#timetime-support) * [Unicode support](#unicode-support) * [Testing / Development](#testing--development) * [License](#license) --------------------------------------- ## Features * Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance") * Native Go implementation. 
No C-bindings, just pure Go * Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc) * Automatic handling of broken connections * Automatic Connection Pooling *(by database/sql package)* * Supports queries larger than 16MB * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. * Intelligent `LONG DATA` handling in prepared statements * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation ## Requirements * Go 1.7 or higher. We aim to support the 3 latest versions of Go. * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- ## Installation Simply install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from the shell: ```bash $ go get -u github.com/go-sql-driver/mysql ``` Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. ## Usage _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can then use the full [`database/sql`](https://golang.org/pkg/database/sql/) API. Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: ```go import "database/sql" import _ "github.com/go-sql-driver/mysql" db, err := sql.Open("mysql", "user:password@/dbname") ``` [Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples"). ### DSN (Data Source Name) The Data Source Name has a common format, like e.g.
[PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by square brackets): ``` [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN] ``` A DSN in its fullest form: ``` username:password@protocol(address)/dbname?param=value ``` Except for the database name, all values are optional. So the minimal DSN is: ``` /dbname ``` If you do not want to preselect a database, leave `dbname` empty: ``` / ``` This has the same effect as an empty DSN string: ``` ``` Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. #### Password Passwords can consist of any character. Escaping is **not** necessary. #### Protocol See [net.Dial](https://golang.org/pkg/net/#Dial) for more information on which networks are available. In general you should use a Unix domain socket if available and TCP otherwise for best performance. #### Address For TCP and UDP networks, addresses have the form `host[:port]`. If `port` is omitted, the default port will be used. If `host` is a literal IPv6 address, it must be enclosed in square brackets. The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](https://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form. For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`. #### Parameters *Parameters are case-sensitive!* Notice that any of `true`, `TRUE`, `True` or `1` is accepted to stand for a true boolean value. Not surprisingly, false can be specified as any of: `false`, `FALSE`, `False` or `0`. ##### `allowAllFiles` ``` Type: bool Valid Values: true, false Default: false ``` `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files.
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) ##### `allowCleartextPasswords` ``` Type: bool Valid Values: true, false Default: false ``` `allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. ##### `allowNativePasswords` ``` Type: bool Valid Values: true, false Default: true ``` `allowNativePasswords=false` disallows the usage of the MySQL native password method. ##### `allowOldPasswords` ``` Type: bool Valid Values: true, false Default: false ``` `allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords). ##### `charset` ``` Type: string Valid Values: Default: none ``` Sets the charset used for client-server interaction (`"SET NAMES "`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset fails. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). Usage of the `charset` parameter is discouraged because it issues additional queries to the server. Unless you need the fallback behavior, please use `collation` instead.
##### `collation` ``` Type: string Valid Values: Default: utf8_general_ci ``` Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail. A list of valid collations for a server is retrievable with `SHOW COLLATION`. ##### `clientFoundRows` ``` Type: bool Valid Values: true, false Default: false ``` `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed. ##### `columnsWithAlias` ``` Type: bool Valid Values: true, false Default: false ``` When `columnsWithAlias` is true, calls to `sql.Rows.Columns()` will return the table alias and the column name separated by a dot. For example: ``` SELECT u.id FROM users as u ``` will return `u.id` instead of just `id` if `columnsWithAlias=true`. ##### `interpolateParams` ``` Type: bool Valid Values: true, false Default: false ``` If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since with `interpolateParams=false` the driver has to prepare a statement, execute it with given parameters and close the statement again. *This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* ##### `loc` ``` Type: string Valid Values: Default: UTC ``` Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details. Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html).
For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter. Please keep in mind that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively, you can manually replace the `/` with `%2F`. For example, `US/Pacific` would be `loc=US%2FPacific`. ##### `maxAllowedPacket` ``` Type: decimal number Default: 4194304 ``` Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from the server *on every connection*. ##### `multiStatements` ``` Type: bool Valid Values: true, false Default: false ``` Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned; all other results are silently discarded. When `multiStatements` is used, `?` parameters must only be used in the first statement. ##### `parseTime` ``` Type: bool Valid Values: true, false Default: false ``` `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`. ##### `readTimeout` ``` Type: duration Default: 0 ``` I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. ##### `rejectReadOnly` ``` Type: bool Valid Values: true, false Default: false ``` `rejectReadOnly=true` causes the driver to reject read-only connections. This is for a possible race condition during an automatic failover, where the mysql client gets connected to a read-only replica after the failover. Note that this should be a fairly rare case, as an automatic failover normally happens when the primary is down, and the race condition shouldn't happen unless it comes back up online as soon as the failover is kicked off.
On the other hand, when this happens, a MySQL application can get stuck on a read-only connection until restarted. It is, however, fairly easy to reproduce, for example, using a manual failover on AWS Aurora's MySQL-compatible cluster. If you are not relying on read-only transactions to reject writes that aren't supposed to happen, setting this on some MySQL providers (such as AWS Aurora) is safer for failovers. Note that ERROR 1290 can be returned for a `read-only` server and this option will cause a retry for that error. However, the same error number is used for some other cases. You should ensure your application will never cause an ERROR 1290 except for `read-only` mode when enabling this option. ##### `timeout` ``` Type: duration Default: OS default ``` Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. ##### `tls` ``` Type: bool / string Valid Values: true, false, skip-verify, <name> Default: false ``` `tls=true` enables a TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use `<name>` for a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). ##### `writeTimeout` ``` Type: duration Default: 0 ``` I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. ##### System Variables Any other parameters are interpreted as system variables: * `<boolean_var>=<value>`: `SET <boolean_var>=<value>` * `<enum_var>=<value>`: `SET <enum_var>=<value>` * `<string_var>=%27<value>%27`: `SET <string_var>='<value>'` Rules: * The values for string variables must be quoted with `'`. * The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed! (which implies values of string variables must be wrapped with `%27`).
Examples: * `autocommit=1`: `SET autocommit=1` * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` #### Examples ``` user@unix(/path/to/socket)/dbname ``` ``` root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local ``` ``` user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true ``` Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html): ``` user:password@/dbname?sql_mode=TRADITIONAL ``` TCP via IPv6: ``` user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci ``` TCP on a remote host, e.g. Amazon RDS: ``` id:password@tcp(your-amazonaws-uri.com:3306)/dbname ``` Google Cloud SQL on App Engine (First Generation MySQL Server): ``` user@cloudsql(project-id:instance-name)/dbname ``` Google Cloud SQL on App Engine (Second Generation MySQL Server): ``` user@cloudsql(project-id:regionname:instance-name)/dbname ``` TCP using default port (3306) on localhost: ``` user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped ``` Use the default protocol (tcp) and host (localhost:3306): ``` user:password@/dbname ``` No Database preselected: ``` user:password@/ ``` ### Connection pool and timeouts The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. 
## `ColumnType` Support This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. ## `context.Context` Support Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): ```go import "github.com/go-sql-driver/mysql" ``` Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). To use an `io.Reader`, a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns an `io.Reader` or `io.ReadCloser`. The Reader is then available with the filepath `Reader::<name>`. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details. ### `time.Time` support The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical counterpart in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`.
You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks, for example, [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). Alternatively, you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. ### Unicode support Since version 1.1 Go-MySQL-Driver automatically uses the collation `utf8_general_ci` by default. Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred for setting a collation / charset other than the default. See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. Go-MySQL-Driver is not feature-complete yet. Your help is greatly appreciated. If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls). See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details.
--------------------------------------- ## License Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE) Mozilla summarizes the license scope as follows: > MPL: The copyleft applies to any files containing MPLed code. That means: * You can **use** the **unchanged** source code both in private and commercially. * When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0). * You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**. Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license. You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). ![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") ================================================ FILE: vendor/github.com/go-sql-driver/mysql/appengine.go ================================================ // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. 
// +build appengine

package mysql

import (
	"google.golang.org/appengine/cloudsql"
)

// init registers cloudsql.Dial under the network name "cloudsql", the
// protocol used by DSNs of the form `user@cloudsql(project-id:instance-name)/dbname`.
func init() {
	RegisterDial("cloudsql", cloudsql.Dial)
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/buffer.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"io"
	"net"
	"time"
)

// defaultBufSize is the initial size, in bytes, of every connection buffer
// and the granularity in which fill grows it.
const defaultBufSize = 4096

// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
//
// Unread data always lives in buf[idx : idx+length].
type buffer struct {
	buf     []byte        // backing storage, shared by reads and outgoing packets
	nc      net.Conn      // connection that fill reads from
	idx     int           // offset of the first unread byte in buf
	length  int           // number of unread bytes starting at idx
	timeout time.Duration // when > 0, fill sets a read deadline of now+timeout before each Read
}

// newBuffer returns a buffer that reads from nc, backed by a
// defaultBufSize-byte slice.
func newBuffer(nc net.Conn) buffer {
	var b [defaultBufSize]byte
	return buffer{
		buf: b[:],
		nc:  nc,
	}
}

// fill reads into the buffer until at least _need_ bytes are in it
//
// Any unread bytes are first moved to the front of buf, so after fill
// returns successfully idx is 0 and length counts every byte read (a single
// Read may deliver more than need). If need exceeds len(buf), buf is
// replaced by a larger slice whose size is a multiple of defaultBufSize.
// An EOF before need bytes have arrived is reported as io.ErrUnexpectedEOF;
// any other read or deadline error is returned unchanged.
func (b *buffer) fill(need int) error {
	n := b.length

	// move existing data to the beginning
	if n > 0 && b.idx > 0 {
		copy(b.buf[0:n], b.buf[b.idx:])
	}

	// grow buffer if necessary
	// TODO: let the buffer shrink again at some point
	// Maybe keep the org buf slice and swap back?
	if need > len(b.buf) {
		// Round up to the next multiple of the default size
		// (strictly greater than need, so an exact multiple gains one extra block)
		newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
		copy(newBuf, b.buf)
		b.buf = newBuf
	}

	b.idx = 0

	for {
		if b.timeout > 0 {
			if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
				return err
			}
		}

		// read as much as fits after the data already buffered
		nn, err := b.nc.Read(b.buf[n:])
		n += nn

		switch err {
		case nil:
			if n < need {
				continue
			}
			b.length = n
			return nil

		case io.EOF:
			// the final Read may return data together with EOF
			if n >= need {
				b.length = n
				return nil
			}
			return io.ErrUnexpectedEOF

		default:
			return err
		}
	}
}

// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
//
// readNext calls fill first when fewer than need bytes are buffered, and
// returns fill's error in that case. The slice aliases b.buf; it is not a copy.
func (b *buffer) readNext(need int) ([]byte, error) {
	if b.length < need {
		// refill
		if err := b.fill(need); err != nil {
			return nil, err
		}
	}

	offset := b.idx
	b.idx += need
	b.length -= need
	return b.buf[offset:b.idx], nil
}

// returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
//
// It returns nil while unread data is still buffered (b.length > 0).
// Requests that fit in the current backing array reuse it. Larger requests
// below maxPacketSize replace b.buf, so the bigger array is kept for later
// calls. Requests of maxPacketSize or more get a one-off allocation that
// is not retained.
func (b *buffer) takeBuffer(length int) []byte {
	if b.length > 0 {
		return nil
	}

	// test (cheap) general case first
	if length <= defaultBufSize || length <= cap(b.buf) {
		return b.buf[:length]
	}

	if length < maxPacketSize {
		b.buf = make([]byte, length)
		return b.buf
	}
	return make([]byte, length)
}

// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
//
// There is no size check: if length exceeds cap(b.buf), the slice expression
// panics. It returns nil while unread data is still buffered.
func (b *buffer) takeSmallBuffer(length int) []byte {
	if b.length == 0 {
		return b.buf[:length]
	}
	return nil
}

// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
// It returns nil while unread data is still buffered (b.length > 0).
func (b *buffer) takeCompleteBuffer() []byte {
	if b.length == 0 {
		// nothing left to read: the whole backing slice may be reused
		return b.buf
	}
	// unread data is pending, so the buffer is busy
	return nil
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/collations.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

// defaultCollation is the collation name documented as the DSN default
// for the `collation` parameter.
const defaultCollation = "utf8_general_ci"

// binaryCollation is the name of the "binary" collation (ID 63 below).
const binaryCollation = "binary"

// A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query:
//     SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS
//
// IDs fit in a byte and are not contiguous (for example 17 and 76 are absent).
var collations = map[string]byte{
	"big5_chinese_ci": 1,
	"latin2_czech_cs": 2,
	"dec8_swedish_ci": 3,
	"cp850_general_ci": 4,
	"latin1_german1_ci": 5,
	"hp8_english_ci": 6,
	"koi8r_general_ci": 7,
	"latin1_swedish_ci": 8,
	"latin2_general_ci": 9,
	"swe7_swedish_ci": 10,
	"ascii_general_ci": 11,
	"ujis_japanese_ci": 12,
	"sjis_japanese_ci": 13,
	"cp1251_bulgarian_ci": 14,
	"latin1_danish_ci": 15,
	"hebrew_general_ci": 16,
	"tis620_thai_ci": 18,
	"euckr_korean_ci": 19,
	"latin7_estonian_cs": 20,
	"latin2_hungarian_ci": 21,
	"koi8u_general_ci": 22,
	"cp1251_ukrainian_ci": 23,
	"gb2312_chinese_ci": 24,
	"greek_general_ci": 25,
	"cp1250_general_ci": 26,
	"latin2_croatian_ci": 27,
	"gbk_chinese_ci": 28,
	"cp1257_lithuanian_ci": 29,
	"latin5_turkish_ci": 30,
	"latin1_german2_ci": 31,
	"armscii8_general_ci": 32,
	"utf8_general_ci": 33,
	"cp1250_czech_cs": 34,
	"ucs2_general_ci": 35,
	"cp866_general_ci": 36,
	"keybcs2_general_ci": 37,
	"macce_general_ci": 38,
	"macroman_general_ci": 39,
	"cp852_general_ci": 40,
	"latin7_general_ci": 41,
	"latin7_general_cs": 42,
	"macce_bin": 43,
	"cp1250_croatian_ci": 44,
	"utf8mb4_general_ci": 45,
	"utf8mb4_bin": 46,
	"latin1_bin": 47,
	"latin1_general_ci": 48,
	"latin1_general_cs": 49,
	"cp1251_bin": 50,
	"cp1251_general_ci": 51,
	"cp1251_general_cs": 52,
	"macroman_bin": 53,
	"utf16_general_ci": 54,
	"utf16_bin": 55,
	"utf16le_general_ci": 56,
	"cp1256_general_ci": 57,
	"cp1257_bin": 58,
	"cp1257_general_ci": 59,
	"utf32_general_ci": 60,
	"utf32_bin": 61,
	"utf16le_bin": 62,
	"binary": 63,
	"armscii8_bin": 64,
	"ascii_bin": 65,
	"cp1250_bin": 66,
	"cp1256_bin": 67,
	"cp866_bin": 68,
	"dec8_bin": 69,
	"greek_bin": 70,
	"hebrew_bin": 71,
	"hp8_bin": 72,
	"keybcs2_bin": 73,
	"koi8r_bin": 74,
	"koi8u_bin": 75,
	"latin2_bin": 77,
	"latin5_bin": 78,
	"latin7_bin": 79,
	"cp850_bin": 80,
	"cp852_bin": 81,
	"swe7_bin": 82,
	"utf8_bin": 83,
	"big5_bin": 84,
	"euckr_bin": 85,
	"gb2312_bin": 86,
	"gbk_bin": 87,
	"sjis_bin": 88,
	"tis620_bin": 89,
	"ucs2_bin": 90,
	"ujis_bin": 91,
	"geostd8_general_ci": 92,
	"geostd8_bin": 93,
	"latin1_spanish_ci": 94,
	"cp932_japanese_ci": 95,
	"cp932_bin": 96,
	"eucjpms_japanese_ci": 97,
	"eucjpms_bin": 98,
	"cp1250_polish_ci": 99,
	"utf16_unicode_ci": 101,
	"utf16_icelandic_ci": 102,
	"utf16_latvian_ci": 103,
	"utf16_romanian_ci": 104,
	"utf16_slovenian_ci": 105,
	"utf16_polish_ci": 106,
	"utf16_estonian_ci": 107,
	"utf16_spanish_ci": 108,
	"utf16_swedish_ci": 109,
	"utf16_turkish_ci": 110,
	"utf16_czech_ci": 111,
	"utf16_danish_ci": 112,
	"utf16_lithuanian_ci": 113,
	"utf16_slovak_ci": 114,
	"utf16_spanish2_ci": 115,
	"utf16_roman_ci": 116,
	"utf16_persian_ci": 117,
	"utf16_esperanto_ci": 118,
	"utf16_hungarian_ci": 119,
	"utf16_sinhala_ci": 120,
	"utf16_german2_ci": 121,
	"utf16_croatian_ci": 122,
	"utf16_unicode_520_ci": 123,
	"utf16_vietnamese_ci": 124,
	"ucs2_unicode_ci": 128,
	"ucs2_icelandic_ci": 129,
	"ucs2_latvian_ci": 130,
	"ucs2_romanian_ci": 131,
	"ucs2_slovenian_ci": 132,
	"ucs2_polish_ci": 133,
	"ucs2_estonian_ci": 134,
	"ucs2_spanish_ci": 135,
	"ucs2_swedish_ci": 136,
	"ucs2_turkish_ci": 137,
	"ucs2_czech_ci": 138,
	"ucs2_danish_ci": 139,
	"ucs2_lithuanian_ci": 140,
	"ucs2_slovak_ci": 141,
	"ucs2_spanish2_ci": 142,
	"ucs2_roman_ci": 143,
	"ucs2_persian_ci": 144,
	"ucs2_esperanto_ci": 145,
	"ucs2_hungarian_ci": 146,
	"ucs2_sinhala_ci": 147,
	"ucs2_german2_ci": 148,
	"ucs2_croatian_ci": 149,
	"ucs2_unicode_520_ci": 150,
	"ucs2_vietnamese_ci": 151,
	"ucs2_general_mysql500_ci": 159,
	"utf32_unicode_ci": 160,
	"utf32_icelandic_ci": 161,
	"utf32_latvian_ci": 162,
	"utf32_romanian_ci": 163,
	"utf32_slovenian_ci": 164,
	"utf32_polish_ci": 165,
	"utf32_estonian_ci": 166,
	"utf32_spanish_ci": 167,
	"utf32_swedish_ci": 168,
	"utf32_turkish_ci": 169,
	"utf32_czech_ci": 170,
	"utf32_danish_ci": 171,
	"utf32_lithuanian_ci": 172,
	"utf32_slovak_ci": 173,
	"utf32_spanish2_ci": 174,
	"utf32_roman_ci": 175,
	"utf32_persian_ci": 176,
	"utf32_esperanto_ci": 177,
	"utf32_hungarian_ci": 178,
	"utf32_sinhala_ci": 179,
	"utf32_german2_ci": 180,
	"utf32_croatian_ci": 181,
	"utf32_unicode_520_ci": 182,
	"utf32_vietnamese_ci": 183,
	"utf8_unicode_ci": 192,
	"utf8_icelandic_ci": 193,
	"utf8_latvian_ci": 194,
	"utf8_romanian_ci": 195,
	"utf8_slovenian_ci": 196,
	"utf8_polish_ci": 197,
	"utf8_estonian_ci": 198,
	"utf8_spanish_ci": 199,
	"utf8_swedish_ci": 200,
	"utf8_turkish_ci": 201,
	"utf8_czech_ci": 202,
	"utf8_danish_ci": 203,
	"utf8_lithuanian_ci": 204,
	"utf8_slovak_ci": 205,
	"utf8_spanish2_ci": 206,
	"utf8_roman_ci": 207,
	"utf8_persian_ci": 208,
	"utf8_esperanto_ci": 209,
	"utf8_hungarian_ci": 210,
	"utf8_sinhala_ci": 211,
	"utf8_german2_ci": 212,
	"utf8_croatian_ci": 213,
	"utf8_unicode_520_ci": 214,
	"utf8_vietnamese_ci": 215,
	"utf8_general_mysql500_ci": 223,
	"utf8mb4_unicode_ci": 224,
	"utf8mb4_icelandic_ci": 225,
	"utf8mb4_latvian_ci": 226,
	"utf8mb4_romanian_ci": 227,
	"utf8mb4_slovenian_ci": 228,
	"utf8mb4_polish_ci": 229,
	"utf8mb4_estonian_ci": 230,
	"utf8mb4_spanish_ci": 231,
	"utf8mb4_swedish_ci": 232,
	"utf8mb4_turkish_ci": 233,
	"utf8mb4_czech_ci": 234,
	"utf8mb4_danish_ci": 235,
	"utf8mb4_lithuanian_ci": 236,
	"utf8mb4_slovak_ci": 237,
	"utf8mb4_spanish2_ci": 238,
	"utf8mb4_roman_ci": 239,
	"utf8mb4_persian_ci": 240,
	"utf8mb4_esperanto_ci": 241,
	"utf8mb4_hungarian_ci": 242,
	"utf8mb4_sinhala_ci": 243,
	"utf8mb4_german2_ci": 244,
	"utf8mb4_croatian_ci": 245,
	"utf8mb4_unicode_520_ci": 246,
	"utf8mb4_vietnamese_ci": 247,
}

// A blacklist of collations which are unsafe for client-side parameter
// interpolation. These multibyte encodings may contain 0x5c (`\`) in their
// trailing bytes, so backslash escaping could be defeated.
// NOTE(review): gb2312_bin is listed but gb2312_chinese_ci (24) is not;
// confirm whether this asymmetry is intended.
var unsafeCollations = map[string]bool{
	"big5_chinese_ci": true,
	"sjis_japanese_ci": true,
	"gbk_chinese_ci": true,
	"big5_bin": true,
	"gb2312_bin": true,
	"gbk_bin": true,
	"sjis_bin": true,
	"cp932_japanese_ci": true,
	"cp932_bin": true,
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/connection.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql

import (
	"database/sql/driver"
	"io"
	"net"
	"strconv"
	"strings"
	"time"
)

// a copy of context.Context for Go 1.7 and earlier
// Only the two methods the watcher goroutine needs are declared.
type mysqlContext interface {
	Done() <-chan struct{}
	Err() error

	// defined in context.Context, but not used in this driver:
	// Deadline() (deadline time.Time, ok bool)
	// Value(key interface{}) interface{}
}

// mysqlConn is a single client connection to a MySQL server.
type mysqlConn struct {
	buf              buffer   // shared read/write buffer for nc
	netConn          net.Conn // underlying network connection
	affectedRows     uint64   // reset by Exec and reported in its mysqlResult
	insertId         uint64   // reset by Exec and reported in its mysqlResult
	cfg              *Config
	maxAllowedPacket int // upper bound for interpolated queries
	maxWriteSize     int
	writeTimeout     time.Duration
	flags            clientFlag
	status           statusFlag // server status; checked for statusNoBackslashEscapes
	sequence         uint8
	parseTime        bool

	// for context support (Go 1.8+)
	watching bool                 // true while a context is handed to the watcher and not yet finished
	watcher  chan<- mysqlContext  // contexts to watch, consumed by the watcher goroutine
	closech  chan struct{}        // closed exactly once by cleanup
	finished chan<- struct{}      // finish signals the watcher here
	canceled atomicError // set non-nil if conn is canceled
	closed   atomicBool  // set when conn is closed, before closech is closed
}

// Handles parameters set in DSN after the connection is established
//
// "charset" is a comma-separated list tried in order with SET NAMES; the
// first charset that succeeds wins, and the last error is returned only if
// all of them fail. Every other parameter is sent as `SET <param>=<val>`
// with no escaping (presumably the DSN is trusted input — confirm).
// Map iteration order is unspecified, so the SET statements may run in any order.
func (mc *mysqlConn) handleParams() (err error) {
	for param, val := range mc.cfg.Params {
		switch param {
		// Charset
		case "charset":
			charsets := strings.Split(val, ",")
			for i := range charsets {
				// ignore errors here - a charset may not exist
				err = mc.exec("SET NAMES " + charsets[i])
				if err == nil {
					break
				}
			}
			if err != nil {
				return
			}

		// System Vars
		default:
			err = mc.exec("SET " + param + "=" + val + "")
			if err != nil {
				return
			}
		}
	}

	return
}

// markBadConn converts errBadConnNoWrite into driver.ErrBadConn, which tells
// database/sql the connection is unusable and the operation may be retried
// on another connection. Any other error, and any error on a nil receiver,
// is returned unchanged.
func (mc *mysqlConn) markBadConn(err error) error {
	if mc == nil {
		return err
	}
	if err != errBadConnNoWrite {
		return err
	}
	return driver.ErrBadConn
}

// Begin starts a read-write transaction.
func (mc *mysqlConn) Begin() (driver.Tx, error) {
	return mc.begin(false)
}

// begin issues START TRANSACTION (READ ONLY when readOnly is set) and wraps
// the connection in a mysqlTx. A closed connection yields driver.ErrBadConn.
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
	if mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	var q string
	if readOnly {
		q = "START TRANSACTION READ ONLY"
	} else {
		q = "START TRANSACTION"
	}
	err := mc.exec(q)
	if err == nil {
		return &mysqlTx{mc}, err
	}
	return nil, mc.markBadConn(err)
}

// Close sends COM_QUIT if the connection is still open and then always runs
// cleanup. It returns the error from writing COM_QUIT, if any.
func (mc *mysqlConn) Close() (err error) {
	// Makes Close idempotent
	if !mc.closed.IsSet() {
		err = mc.writeCommandPacket(comQuit)
	}

	mc.cleanup()

	return
}

// Closes the network connection and unsets internal variables. Do not call this
// function after successful authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
	// Makes cleanup idempotent: only the first caller to flip closed proceeds.
	if !mc.closed.TrySet(true) {
		return
	}

	// wakes the watcher goroutine and anyone blocked in finish
	close(mc.closech)
	if mc.netConn == nil {
		return
	}
	if err := mc.netConn.Close(); err != nil {
		errLog.Print(err)
	}
}

// error reports why a closed connection is unusable: the cancellation cause
// recorded by cancel if there is one, otherwise ErrInvalidConn. It returns
// nil while the connection is open.
func (mc *mysqlConn) error() error {
	if mc.closed.IsSet() {
		if err := mc.canceled.Value(); err != nil {
			return err
		}
		return ErrInvalidConn
	}
	return nil
}

// Prepare sends COM_STMT_PREPARE and reads the response, including the
// parameter and column definition packets (each read until EOF).
// NOTE(review): when readPrepareResultPacket or the column read fails, a
// non-nil stmt is returned together with the error; confirm callers ignore it.
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
	if mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	// Send command
	err := mc.writeCommandPacketStr(comStmtPrepare, query)
	if err != nil {
		return nil, mc.markBadConn(err)
	}

	stmt := &mysqlStmt{
		mc: mc,
	}

	// Read Result
	columnCount, err := stmt.readPrepareResultPacket()
	if err == nil {
		if stmt.paramCount > 0 {
			if err = mc.readUntilEOF(); err != nil {
				return nil, err
			}
		}

		if columnCount > 0 {
			err = mc.readUntilEOF()
		}
	}

	return stmt, err
}

// interpolateParams builds a single query string by replacing each '?' in
// query with the escaped literal form of the matching argument, so the
// statement can be sent without a server-side prepare.
//
// It returns driver.ErrSkip in three cases: the number of '?' bytes differs
// from len(args), an argument has an unsupported type, or the result would
// exceed mc.maxAllowedPacket. Placeholders are found by plain byte search, so
// a '?' inside a string literal or comment also counts. The output is built in
// the connection buffer; if that buffer is busy, ErrInvalidConn is returned.
// The returned string is a copy, so it remains valid after the buffer is reused.
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
	// Number of ? should be same to len(args)
	if strings.Count(query, "?") != len(args) {
		return "", driver.ErrSkip
	}

	buf := mc.buf.takeCompleteBuffer()
	if buf == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return "", ErrInvalidConn
	}
	buf = buf[:0]
	argPos := 0

	for i := 0; i < len(query); i++ {
		q := strings.IndexByte(query[i:], '?')
		if q == -1 {
			buf = append(buf, query[i:]...)
			break
		}
		buf = append(buf, query[i:i+q]...)
		i += q // the loop's i++ then skips the '?' itself

		arg := args[argPos]
		argPos++

		if arg == nil {
			buf = append(buf, "NULL"...)
			continue
		}

		switch v := arg.(type) {
		case int64:
			buf = strconv.AppendInt(buf, v, 10)
		case float64:
			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
		case bool:
			if v {
				buf = append(buf, '1')
			} else {
				buf = append(buf, '0')
			}
		case time.Time:
			if v.IsZero() {
				buf = append(buf, "'0000-00-00'"...)
			} else {
				// Rendered in mc.cfg.Loc as 'YYYY-MM-DD HH:MM:SS[.ffffff]';
				// the fractional part is omitted when it is zero.
				v := v.In(mc.cfg.Loc)
				v = v.Add(time.Nanosecond * 500) // To round under microsecond
				year := v.Year()
				year100 := year / 100
				year1 := year % 100
				month := v.Month()
				day := v.Day()
				hour := v.Hour()
				minute := v.Minute()
				second := v.Second()
				micro := v.Nanosecond() / 1000

				buf = append(buf, []byte{
					'\'',
					digits10[year100], digits01[year100],
					digits10[year1], digits01[year1],
					'-',
					digits10[month], digits01[month],
					'-',
					digits10[day], digits01[day],
					' ',
					digits10[hour], digits01[hour],
					':',
					digits10[minute], digits01[minute],
					':',
					digits10[second], digits01[second],
				}...)

				if micro != 0 {
					micro10000 := micro / 10000
					micro100 := micro / 100 % 100
					micro1 := micro % 100
					buf = append(buf, []byte{
						'.',
						digits10[micro10000], digits01[micro10000],
						digits10[micro100], digits01[micro100],
						digits10[micro1], digits01[micro1],
					}...)
				}
				buf = append(buf, '\'')
			}
		case []byte:
			if v == nil {
				buf = append(buf, "NULL"...)
			} else {
				buf = append(buf, "_binary'"...)
				// backslash escaping unless the server reports NO_BACKSLASH_ESCAPES
				if mc.status&statusNoBackslashEscapes == 0 {
					buf = escapeBytesBackslash(buf, v)
				} else {
					buf = escapeBytesQuotes(buf, v)
				}
				buf = append(buf, '\'')
			}
		case string:
			buf = append(buf, '\'')
			// backslash escaping unless the server reports NO_BACKSLASH_ESCAPES
			if mc.status&statusNoBackslashEscapes == 0 {
				buf = escapeStringBackslash(buf, v)
			} else {
				buf = escapeStringQuotes(buf, v)
			}
			buf = append(buf, '\'')
		default:
			return "", driver.ErrSkip
		}

		// the +4 presumably reserves room for the packet header — confirm
		if len(buf)+4 > mc.maxAllowedPacket {
			return "", driver.ErrSkip
		}
	}
	if argPos != len(args) {
		return "", driver.ErrSkip
	}
	return string(buf), nil
}

// Exec runs query as a text-protocol statement. When args are given it
// returns driver.ErrSkip unless InterpolateParams is enabled, in which case
// the args are interpolated client-side first. On success the affected-row
// count and last insert ID of this statement are returned.
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	if mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	if len(args) != 0 {
		if !mc.cfg.InterpolateParams {
			return nil, driver.ErrSkip
		}
		// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
		prepared, err := mc.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		query = prepared
	}
	mc.affectedRows = 0
	mc.insertId = 0

	err := mc.exec(query)
	if err == nil {
		return &mysqlResult{
			affectedRows: int64(mc.affectedRows),
			insertId:     int64(mc.insertId),
		}, err
	}
	return nil, mc.markBadConn(err)
}

// Internal function to execute commands
//
// exec sends query via COM_QUERY and drains its response. If it produced a
// result set, the column definitions and rows are read until EOF and
// discarded. Any further results are consumed by discardResults. Only write
// failures go through markBadConn.
func (mc *mysqlConn) exec(query string) error {
	// Send command
	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
		return mc.markBadConn(err)
	}

	// Read Result
	resLen, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return err
	}

	if resLen > 0 {
		// columns
		if err := mc.readUntilEOF(); err != nil {
			return err
		}

		// rows
		if err := mc.readUntilEOF(); err != nil {
			return err
		}
	}

	return mc.discardResults()
}

// Query runs query as a text-protocol query; see query.
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	return mc.query(query, args)
}

// query sends query via COM_QUERY and returns a textRows positioned after
// the column definitions. Args follow the same rules as Exec. If the first
// result has no columns, NextResultSet is called immediately; a nil or
// io.EOF outcome still returns the (empty) rows.
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
	if mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	if len(args) != 0 {
		if !mc.cfg.InterpolateParams {
			return nil, driver.ErrSkip
		}
		// try client-side prepare to reduce roundtrip
		prepared, err := mc.interpolateParams(query, args)
		if err != nil {
			return nil, err
		}
		query = prepared
	}
	// Send command
	err := mc.writeCommandPacketStr(comQuery, query)
	if err == nil {
		// Read Result
		var resLen int
		resLen, err = mc.readResultSetHeaderPacket()
		if err == nil {
			rows := new(textRows)
			rows.mc = mc

			if resLen == 0 {
				rows.rs.done = true

				switch err := rows.NextResultSet(); err {
				case nil, io.EOF:
					return rows, nil
				default:
					return nil, err
				}
			}

			// Columns
			rows.rs.columns, err = mc.readColumns(resLen)
			return rows, err
		}
	}
	return nil, mc.markBadConn(err)
}

// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
//
// name is concatenated into "SELECT @@<name>" without escaping, so it must
// be a trusted identifier.
// NOTE(review): dest[0].([]byte) panics if resLen is 0 or if the value is
// not a []byte (e.g. a NULL variable); consider a checked assertion.
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
	// Send command
	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
		return nil, err
	}

	// Read Result
	resLen, err := mc.readResultSetHeaderPacket()
	if err == nil {
		rows := new(textRows)
		rows.mc = mc
		rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

		if resLen > 0 {
			// Columns
			if err := mc.readUntilEOF(); err != nil {
				return nil, err
			}
		}

		dest := make([]driver.Value, resLen)
		if err = rows.readRow(dest); err == nil {
			return dest[0].([]byte), mc.readUntilEOF()
		}
	}
	return nil, err
}

// cancel is called when the query has been canceled. It records err as the
// cancellation cause (later reported by error) and tears the connection
// down via cleanup.
func (mc *mysqlConn) cancel(err error) {
	mc.canceled.Set(err)
	mc.cleanup()
}

// finish is called when the query has succeeded.
// It tells the watcher goroutine that the watched operation completed and
// clears mc.watching. It is a no-op when nothing is being watched. If the
// connection is closed concurrently, it returns without clearing mc.watching.
func (mc *mysqlConn) finish() {
	if !mc.watching || mc.finished == nil {
		return
	}
	select {
	case mc.finished <- struct{}{}:
		mc.watching = false
	case <-mc.closech:
	}
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/connection_go18.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
// // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. // +build go1.8 package mysql import ( "context" "database/sql" "database/sql/driver" ) // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) error { if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return driver.ErrBadConn } if err := mc.watchCancel(ctx); err != nil { return err } defer mc.finish() if err := mc.writeCommandPacket(comPing); err != nil { return err } if _, err := mc.readResultOK(); err != nil { return err } return nil } // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if err := mc.watchCancel(ctx); err != nil { return nil, err } defer mc.finish() if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { level, err := mapIsolationLevel(opts.Isolation) if err != nil { return nil, err } err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) if err != nil { return nil, err } } return mc.begin(opts.ReadOnly) } func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := mc.watchCancel(ctx); err != nil { return nil, err } rows, err := mc.query(query, dargs) if err != nil { mc.finish() return nil, err } rows.finish = mc.finish return rows, err } func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := mc.watchCancel(ctx); err != nil { return nil, err } defer mc.finish() return mc.Exec(query, dargs) } func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if err := mc.watchCancel(ctx); err != nil { 
return nil, err } stmt, err := mc.Prepare(query) mc.finish() if err != nil { return nil, err } select { default: case <-ctx.Done(): stmt.Close() return nil, ctx.Err() } return stmt, nil } func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := stmt.mc.watchCancel(ctx); err != nil { return nil, err } rows, err := stmt.query(dargs) if err != nil { stmt.mc.finish() return nil, err } rows.finish = stmt.mc.finish return rows, err } func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := stmt.mc.watchCancel(ctx); err != nil { return nil, err } defer stmt.mc.finish() return stmt.Exec(dargs) } func (mc *mysqlConn) watchCancel(ctx context.Context) error { if mc.watching { // Reach here if canceled, // so the connection is already invalid mc.cleanup() return nil } if ctx.Done() == nil { return nil } mc.watching = true select { default: case <-ctx.Done(): return ctx.Err() } if mc.watcher == nil { return nil } mc.watcher <- ctx return nil } func (mc *mysqlConn) startWatcher() { watcher := make(chan mysqlContext, 1) mc.watcher = watcher finished := make(chan struct{}) mc.finished = finished go func() { for { var ctx mysqlContext select { case ctx = <-watcher: case <-mc.closech: return } select { case <-ctx.Done(): mc.cancel(ctx.Err()) case <-finished: case <-mc.closech: return } } }() } func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { nv.Value, err = converter{}.ConvertValue(nv.Value) return } ================================================ FILE: vendor/github.com/go-sql-driver/mysql/const.go ================================================ // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. 
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

const (
	defaultMaxAllowedPacket = 4 << 20 // 4 MiB
	minProtocolVersion      = 10
	maxPacketSize           = 1<<24 - 1 // largest value of a 3-byte length (presumably the protocol's packet-length field)
	timeFormat              = "2006-01-02 15:04:05.999999" // Go reference layout; .999999 = up to microseconds, trailing zeros dropped
)

// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html

// Leading byte of a server response packet, presumably used by the packet
// readers to tell OK / LOCAL INFILE request / EOF / ERR packets apart.
const (
	iOK          byte = 0x00
	iLocalInFile byte = 0xfb
	iEOF         byte = 0xfe
	iERR         byte = 0xff
)

// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
// Bit flags; the declaration order below fixes each bit position.
type clientFlag uint32

const (
	clientLongPassword clientFlag = 1 << iota
	clientFoundRows
	clientLongFlag
	clientConnectWithDB
	clientNoSchema
	clientCompress
	clientODBC
	clientLocalFiles
	clientIgnoreSpace
	clientProtocol41
	clientInteractive
	clientSSL
	clientIgnoreSIGPIPE
	clientTransactions
	clientReserved
	clientSecureConn
	clientMultiStatements
	clientMultiResults
	clientPSMultiResults
	clientPluginAuth
	clientConnectAttrs
	clientPluginAuthLenEncClientData
	clientCanHandleExpiredPasswords
	clientSessionTrack
	clientDeprecateEOF
)

// Command bytes; iota + 1 makes comQuit = 0x01 and each following command
// one higher, so the order must not change.
const (
	comQuit byte = iota + 1
	comInitDB
	comQuery
	comFieldList
	comCreateDB
	comDropDB
	comRefresh
	comShutdown
	comStatistics
	comProcessInfo
	comConnect
	comProcessKill
	comDebug
	comPing
	comTime
	comDelayedInsert
	comChangeUser
	comBinlogDump
	comTableDump
	comConnectOut
	comRegisterSlave
	comStmtPrepare
	comStmtExecute
	comStmtSendLongData
	comStmtClose
	comStmtReset
	comSetOption
	comStmtFetch
)

// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
type fieldType byte

// Low range of column types: fieldTypeDecimal = 0x00 ... fieldTypeBit = 0x10.
const (
	fieldTypeDecimal fieldType = iota
	fieldTypeTiny
	fieldTypeShort
	fieldTypeLong
	fieldTypeFloat
	fieldTypeDouble
	fieldTypeNULL
	fieldTypeTimestamp
	fieldTypeLongLong
	fieldTypeInt24
	fieldTypeDate
	fieldTypeTime
	fieldTypeDateTime
	fieldTypeYear
	fieldTypeNewDate
	fieldTypeVarChar
	fieldTypeBit
)

// High range of column types: fieldTypeJSON = 0xf5 ... fieldTypeGeometry = 0xff.
const (
	fieldTypeJSON fieldType = iota + 0xf5
	fieldTypeNewDecimal
	fieldTypeEnum
	fieldTypeSet
	fieldTypeTinyBLOB
	fieldTypeMediumBLOB
	fieldTypeLongBLOB
	fieldTypeBLOB
	fieldTypeVarString
	fieldTypeString
	fieldTypeGeometry
)

// Column definition flag bits; the declaration order fixes each bit position.
type fieldFlag uint16

const (
	flagNotNULL fieldFlag = 1 << iota
	flagPriKey
	flagUniqueKey
	flagMultipleKey
	flagBLOB
	flagUnsigned
	flagZeroFill
	flagBinary
	flagEnum
	flagAutoIncrement
	flagTimestamp
	flagSet
	flagUnknown1
	flagUnknown2
	flagUnknown3
	flagUnknown4
)

// http://dev.mysql.com/doc/internals/en/status-flags.html
// Server status bits; the declaration order fixes each bit position.
type statusFlag uint16

const (
	statusInTrans statusFlag = 1 << iota
	statusInAutocommit
	statusReserved // Not in documentation
	statusMoreResultsExists
	statusNoGoodIndexUsed
	statusNoIndexUsed
	statusCursorExists
	statusLastRowSent
	statusDbDropped
	statusNoBackslashEscapes
	statusMetadataChanged
	statusQueryWasSlow
	statusPsOutParams
	statusInTransReadonly
	statusSessionStateChanged
)

================================================
FILE: vendor/github.com/go-sql-driver/mysql/driver.go
================================================
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

// Package mysql provides a MySQL driver for Go's database/sql package.
//
// The driver should be used via the database/sql package:
//
//  import "database/sql"
//  import _ "github.com/go-sql-driver/mysql"
//
//  db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql

import (
	"database/sql"
	"database/sql/driver"
	"net"
	"sync"
)

// watcher interface is used for context support (From Go 1.8)
// Open type-asserts the connection against it and calls startWatcher when
// it is implemented.
type watcher interface {
	startWatcher()
}

// MySQLDriver is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}

// DialFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDial
type DialFunc func(addr string) (net.Conn, error)

// dials maps a network name to a custom dial function. It is created lazily
// by RegisterDial and guarded by dialsLock.
var (
	dialsLock sync.RWMutex
	dials     map[string]DialFunc
)

// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
func RegisterDial(net string, dial DialFunc) {
	dialsLock.Lock()
	defer dialsLock.Unlock()
	if dials == nil {
		dials = make(map[string]DialFunc)
	}
	dials[net] = dial
}

// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formatted
//
// Steps: parse the DSN, dial (custom dialer if one is registered for
// cfg.Net), start the context watcher, do the handshake and authentication,
// determine max_allowed_packet, then apply the DSN params.
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
	var err error

	// New mysqlConn
	mc := &mysqlConn{
		maxAllowedPacket: maxPacketSize,
		maxWriteSize:     maxPacketSize - 1,
		closech:          make(chan struct{}),
	}
	mc.cfg, err = ParseDSN(dsn)
	if err != nil {
		return nil, err
	}
	mc.parseTime = mc.cfg.ParseTime

	// Connect to Server
	dialsLock.RLock()
	dial, ok := dials[mc.cfg.Net]
	dialsLock.RUnlock()
	if ok {
		mc.netConn, err = dial(mc.cfg.Addr)
	} else {
		nd := net.Dialer{Timeout: mc.cfg.Timeout}
		mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
	}
	if err != nil {
		return nil, err
	}

	// Enable TCP Keepalives on TCP connections
	if tc, ok := mc.netConn.(*net.TCPConn); ok {
		if err := tc.SetKeepAlive(true); err != nil {
			// Don't send COM_QUIT before handshake.
			mc.netConn.Close()
			mc.netConn = nil
			return nil, err
		}
	}

	// Call startWatcher for context support (From Go 1.8)
	if s, ok := interface{}(mc).(watcher); ok {
		s.startWatcher()
	}

	mc.buf = newBuffer(mc.netConn)

	// Set I/O timeouts
	mc.buf.timeout = mc.cfg.ReadTimeout
	mc.writeTimeout = mc.cfg.WriteTimeout

	// Reading Handshake Initialization Packet
	cipher, err := mc.readInitPacket()
	if err != nil {
		mc.cleanup()
		return nil, err
	}

	// Send Client Authentication Packet
	if err = mc.writeAuthPacket(cipher); err != nil {
		mc.cleanup()
		return nil, err
	}

	// Handle response to auth packet, switch methods if possible
	if err = handleAuthResult(mc, cipher); err != nil {
		// Authentication failed and MySQL has already closed the connection
		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
		// Do not send COM_QUIT, just cleanup and return the error.
		mc.cleanup()
		return nil, err
	}

	// From here on the session is authenticated, so failures use mc.Close()
	// rather than mc.cleanup() (presumably so COM_QUIT is sent; Close is not
	// visible here).
	if mc.cfg.MaxAllowedPacket > 0 {
		mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
	} else {
		// Get max allowed packet size
		maxap, err := mc.getSystemVar("max_allowed_packet")
		if err != nil {
			mc.Close()
			return nil, err
		}
		mc.maxAllowedPacket = stringToInt(maxap) - 1
	}
	if mc.maxAllowedPacket < maxPacketSize {
		mc.maxWriteSize = mc.maxAllowedPacket
	}

	// Handle DSN Params
	err = mc.handleParams()
	if err != nil {
		mc.Close()
		return nil, err
	}

	return mc, nil
}

// handleAuthResult reads the server's reply to the authentication packet.
// If the server asks for an old, cleartext or native password exchange and
// the matching Allow* option is enabled in the config, it retries once with
// that method. Otherwise it returns the server's error unchanged.
func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
	// Read Result Packet
	cipher, err := mc.readResultOK()
	if err == nil {
		return nil // auth successful
	}

	if mc.cfg == nil {
		return err // auth failed and retry not possible
	}

	// Retry auth if configured to do so.
	if mc.cfg.AllowOldPasswords && err == ErrOldPassword {
		// Retry with old authentication method. Note: there are edge cases
		// where this should work but doesn't; this is currently "wontfix":
		// https://github.com/go-sql-driver/mysql/issues/184

		// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
		// sent and we have to keep using the cipher sent in the init packet.
		if cipher == nil {
			cipher = oldCipher
		}

		if err = mc.writeOldAuthPacket(cipher); err != nil {
			return err
		}
		_, err = mc.readResultOK()
	} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
		// Retry with clear text password for
		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
		if err = mc.writeClearAuthPacket(); err != nil {
			return err
		}
		_, err = mc.readResultOK()
	} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword {
		if err = mc.writeNativeAuthPacket(cipher); err != nil {
			return err
		}
		_, err = mc.readResultOK()
	}
	return err
}

// init registers the driver with database/sql under the name "mysql".
func init() {
	sql.Register("mysql", &MySQLDriver{})
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/dsn.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"
)

// Errors returned by ParseDSN / normalize for malformed DSNs.
var (
	errInvalidDSNUnescaped       = errors.New("invalid DSN: did you forget to escape a param value?")
	errInvalidDSNAddr            = errors.New("invalid DSN: network address not terminated (missing closing brace)")
	errInvalidDSNNoSlash         = errors.New("invalid DSN: missing the slash separating the database name")
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)

// Config is a configuration parsed from a DSN string.
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
	User             string            // Username
	Passwd           string            // Password (requires User)
	Net              string            // Network type ("tcp" if empty after normalize)
	Addr             string            // Network address (requires Net)
	DBName           string            // Database name
	Params           map[string]string // Connection parameters not recognized by parseDSNParams
	Collation        string            // Connection collation (NewConfig default: defaultCollation)
	Loc              *time.Location    // Location for time.Time values (NewConfig default: time.UTC)
	MaxAllowedPacket int               // Max packet size allowed; <= 0 makes Open query the server's max_allowed_packet
	TLSConfig        string            // TLS configuration name ("true", "false", "skip-verify" or a registered name)
	tls              *tls.Config       // TLS configuration
	Timeout          time.Duration     // Dial timeout
	ReadTimeout      time.Duration     // I/O read timeout
	WriteTimeout     time.Duration     // I/O write timeout

	AllowAllFiles           bool // Allow all files to be used with LOAD DATA LOCAL INFILE
	AllowCleartextPasswords bool // Allows the cleartext client side plugin
	AllowNativePasswords    bool // Allows the native password authentication method (NewConfig default: true)
	AllowOldPasswords       bool // Allows the old insecure password method
	ClientFoundRows         bool // Return number of matching rows instead of rows changed
	ColumnsWithAlias        bool // Prepend table alias to column names
	InterpolateParams       bool // Interpolate placeholders into query string
	MultiStatements         bool // Allow multiple statements in one query
	ParseTime               bool // Parse time values to time.Time
	RejectReadOnly          bool // Reject read-only connections
}

// NewConfig creates a new Config and sets default values.
func NewConfig() *Config {
	// Only these fields get non-zero defaults; everything else keeps its
	// zero value.
	return &Config{
		Collation:            defaultCollation,
		Loc:                  time.UTC,
		MaxAllowedPacket:     defaultMaxAllowedPacket,
		AllowNativePasswords: true,
	}
}

// normalize validates cfg and fills in derived defaults:
//   - it rejects InterpolateParams combined with an unsafe collation;
//   - an empty Net becomes "tcp";
//   - an empty Addr gets a per-network default;
//   - tcp addresses without a port get 3306;
//   - the TLS ServerName is taken from the host part of Addr.
func (cfg *Config) normalize() error {
	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
		return errInvalidDSNUnsafeCollation
	}

	// Set default network if empty
	if cfg.Net == "" {
		cfg.Net = "tcp"
	}

	// Set default address if empty
	if cfg.Addr == "" {
		switch cfg.Net {
		case "tcp":
			cfg.Addr = "127.0.0.1:3306"
		case "unix":
			cfg.Addr = "/tmp/mysql.sock"
		default:
			return errors.New("default addr for network '" + cfg.Net + "' unknown")
		}

	} else if cfg.Net == "tcp" {
		cfg.Addr = ensureHavePort(cfg.Addr)
	}

	// Use the host part of Addr as the expected certificate name unless one
	// was set explicitly or verification is disabled.
	if cfg.tls != nil {
		if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
			host, _, err := net.SplitHostPort(cfg.Addr)
			if err == nil {
				cfg.tls.ServerName = host
			}
		}
	}

	return nil
}

// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
//
// Only options that differ from their NewConfig defaults are emitted, in a
// fixed order, followed by cfg.Params sorted by key. NOTE(review): User,
// Passwd, Net, Addr and DBName are written verbatim (not escaped).
func (cfg *Config) FormatDSN() string {
	var buf bytes.Buffer

	// [username[:password]@]
	if len(cfg.User) > 0 {
		buf.WriteString(cfg.User)
		if len(cfg.Passwd) > 0 {
			buf.WriteByte(':')
			buf.WriteString(cfg.Passwd)
		}
		buf.WriteByte('@')
	}

	// [protocol[(address)]]
	if len(cfg.Net) > 0 {
		buf.WriteString(cfg.Net)
		if len(cfg.Addr) > 0 {
			buf.WriteByte('(')
			buf.WriteString(cfg.Addr)
			buf.WriteByte(')')
		}
	}

	// /dbname
	buf.WriteByte('/')
	buf.WriteString(cfg.DBName)

	// [?param1=value1&...&paramN=valueN]
	// hasParam selects '?' for the first parameter and '&' for the rest.
	hasParam := false

	if cfg.AllowAllFiles {
		hasParam = true
		buf.WriteString("?allowAllFiles=true")
	}

	if cfg.AllowCleartextPasswords {
		if hasParam {
			buf.WriteString("&allowCleartextPasswords=true")
		} else {
			hasParam = true
			buf.WriteString("?allowCleartextPasswords=true")
		}
	}

	if !cfg.AllowNativePasswords {
		if hasParam {
			buf.WriteString("&allowNativePasswords=false")
		} else {
			hasParam = true
			buf.WriteString("?allowNativePasswords=false")
		}
	}

	if cfg.AllowOldPasswords {
		if hasParam {
			buf.WriteString("&allowOldPasswords=true")
		} else {
			hasParam = true
			buf.WriteString("?allowOldPasswords=true")
		}
	}

	if cfg.ClientFoundRows {
		if hasParam {
			buf.WriteString("&clientFoundRows=true")
		} else {
			hasParam = true
			buf.WriteString("?clientFoundRows=true")
		}
	}

	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
		if hasParam {
			buf.WriteString("&collation=")
		} else {
			hasParam = true
			buf.WriteString("?collation=")
		}
		buf.WriteString(col)
	}

	if cfg.ColumnsWithAlias {
		if hasParam {
			buf.WriteString("&columnsWithAlias=true")
		} else {
			hasParam = true
			buf.WriteString("?columnsWithAlias=true")
		}
	}

	if cfg.InterpolateParams {
		if hasParam {
			buf.WriteString("&interpolateParams=true")
		} else {
			hasParam = true
			buf.WriteString("?interpolateParams=true")
		}
	}

	if cfg.Loc != time.UTC && cfg.Loc != nil {
		if hasParam {
			buf.WriteString("&loc=")
		} else {
			hasParam = true
			buf.WriteString("?loc=")
		}
		buf.WriteString(url.QueryEscape(cfg.Loc.String()))
	}

	if cfg.MultiStatements {
		if hasParam {
			buf.WriteString("&multiStatements=true")
		} else {
			hasParam = true
			buf.WriteString("?multiStatements=true")
		}
	}

	if cfg.ParseTime {
		if hasParam {
			buf.WriteString("&parseTime=true")
		} else {
			hasParam = true
			buf.WriteString("?parseTime=true")
		}
	}

	if cfg.ReadTimeout > 0 {
		if hasParam {
			buf.WriteString("&readTimeout=")
		} else {
			hasParam = true
			buf.WriteString("?readTimeout=")
		}
		buf.WriteString(cfg.ReadTimeout.String())
	}

	if cfg.RejectReadOnly {
		if hasParam {
			buf.WriteString("&rejectReadOnly=true")
		} else {
			hasParam = true
			buf.WriteString("?rejectReadOnly=true")
		}
	}

	if cfg.Timeout > 0 {
		if hasParam {
			buf.WriteString("&timeout=")
		} else {
			hasParam = true
			buf.WriteString("?timeout=")
		}
		buf.WriteString(cfg.Timeout.String())
	}

	if len(cfg.TLSConfig) > 0 {
		if hasParam {
			buf.WriteString("&tls=")
		} else {
			hasParam = true
			buf.WriteString("?tls=")
		}
		buf.WriteString(url.QueryEscape(cfg.TLSConfig))
	}

	if cfg.WriteTimeout > 0 {
		if hasParam {
			buf.WriteString("&writeTimeout=")
		} else {
			hasParam = true
			buf.WriteString("?writeTimeout=")
		}
		buf.WriteString(cfg.WriteTimeout.String())
	}

	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
		if hasParam {
			buf.WriteString("&maxAllowedPacket=")
		} else {
			hasParam = true
			buf.WriteString("?maxAllowedPacket=")
		}
		buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))

	}

	// other params
	// Keys are sorted so the output is deterministic despite map ordering.
	if cfg.Params != nil {
		var params []string
		for param := range cfg.Params {
			params = append(params, param)
		}
		sort.Strings(params)
		for _, param := range params {
			if hasParam {
				buf.WriteByte('&')
			} else {
				hasParam = true
				buf.WriteByte('?')
			}

			buf.WriteString(param)
			buf.WriteByte('=')
			buf.WriteString(url.QueryEscape(cfg.Params[param]))
		}
	}

	return buf.String()
}

// ParseDSN parses the DSN string to a Config
// An empty dsn yields a normalized default Config. A non-empty dsn without
// any '/' is rejected with errInvalidDSNNoSlash.
// NOTE(review): when parseDSNParams fails, the named results are returned
// as-is, so cfg is non-nil alongside err.
func ParseDSN(dsn string) (cfg *Config, err error) {
	// New config with some default values
	cfg = NewConfig()

	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
	// Find the last '/' (since the password or the net addr might contain a '/')
	foundSlash := false
	for i := len(dsn) - 1; i >= 0; i-- {
		if dsn[i] == '/' {
			foundSlash = true
			var j, k int

			// left part is empty if i <= 0
			if i > 0 {
				// [username[:password]@][protocol[(address)]]
				// Find the last '@' in dsn[:i]
				for j = i; j >= 0; j-- {
					if dsn[j] == '@' {
						// username[:password]
						// Find the first ':' in dsn[:j]
						for k = 0; k < j; k++ {
							if dsn[k] == ':' {
								cfg.Passwd = dsn[k+1 : j]
								break
							}
						}
						cfg.User = dsn[:k]

						break
					}
				}

				// If no '@' was found, j is -1 here, so the protocol scan
				// below starts at index 0.

				// [protocol[(address)]]
				// Find the first '(' in dsn[j+1:i]
				for k = j + 1; k < i; k++ {
					if dsn[k] == '(' {
						// dsn[i-1] must be == ')' if an address is specified
						if dsn[i-1] != ')' {
							if strings.ContainsRune(dsn[k+1:i], ')') {
								return nil, errInvalidDSNUnescaped
							}
							return nil, errInvalidDSNAddr
						}
						cfg.Addr = dsn[k+1 : i-1]
						break
					}
				}
				cfg.Net = dsn[j+1 : k]
			}

			// dbname[?param1=value1&...&paramN=valueN]
			// Find the first '?' in dsn[i+1:]
			for j = i + 1; j < len(dsn); j++ {
				if dsn[j] == '?' {
					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
						return
					}
					break
				}
			}
			cfg.DBName = dsn[i+1 : j]

			break
		}
	}

	if !foundSlash && len(dsn) > 0 {
		return nil, errInvalidDSNNoSlash
	}

	if err = cfg.normalize(); err != nil {
		return nil, err
	}
	return
}

// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
// Pairs without '=' are ignored. Known keys update cfg fields. Unknown keys
// are URL-unescaped and stored in cfg.Params.
func parseDSNParams(cfg *Config, params string) (err error) {
	for _, v := range strings.Split(params, "&") {
		param := strings.SplitN(v, "=", 2)
		if len(param) != 2 {
			continue
		}

		// cfg params
		switch value := param[1]; param[0] {

		// Disable INFILE whitelist / enable all files
		case "allowAllFiles":
			var isBool bool
			cfg.AllowAllFiles, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Use cleartext authentication mode (MySQL 5.5.10+)
		case "allowCleartextPasswords":
			var isBool bool
			cfg.AllowCleartextPasswords, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Use native password authentication
		case "allowNativePasswords":
			var isBool bool
			cfg.AllowNativePasswords, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Use old authentication mode (pre MySQL 4.1)
		case "allowOldPasswords":
			var isBool bool
			cfg.AllowOldPasswords, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Switch "rowsAffected" mode
		case "clientFoundRows":
			var isBool bool
			cfg.ClientFoundRows, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Collation
		case "collation":
			cfg.Collation = value
			break // redundant: Go switch cases do not fall through

		case "columnsWithAlias":
			var isBool bool
			cfg.ColumnsWithAlias, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Compression
		case "compress":
			return errors.New("compression not implemented yet")

		// Enable client side placeholder substitution
		case "interpolateParams":
			var isBool bool
			cfg.InterpolateParams, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Time Location
		case "loc":
			if value, err = url.QueryUnescape(value); err != nil {
				return
			}
			cfg.Loc, err = time.LoadLocation(value)
			if err != nil {
				return
			}

		// multiple statements in one query
		case "multiStatements":
			var isBool bool
			cfg.MultiStatements, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// time.Time parsing
		case "parseTime":
			var isBool bool
			cfg.ParseTime, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// I/O read Timeout
		case "readTimeout":
			cfg.ReadTimeout, err = time.ParseDuration(value)
			if err != nil {
				return
			}

		// Reject read-only connections
		case "rejectReadOnly":
			var isBool bool
			cfg.RejectReadOnly, isBool = readBool(value)
			if !isBool {
				return errors.New("invalid bool value: " + value)
			}

		// Strict mode
		// NOTE(review): this panics on a DSN value supplied by the caller
		// instead of returning an error.
		case "strict":
			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")

		// Dial Timeout
		case "timeout":
			cfg.Timeout, err = time.ParseDuration(value)
			if err != nil {
				return
			}

		// TLS-Encryption
		// Accepts a bool, "skip-verify" (case-insensitive), or the
		// URL-escaped name of a registered tls.Config.
		case "tls":
			boolValue, isBool := readBool(value)
			if isBool {
				if boolValue {
					cfg.TLSConfig = "true"
					cfg.tls = &tls.Config{}
				} else {
					cfg.TLSConfig = "false"
				}
			} else if vl := strings.ToLower(value); vl == "skip-verify" {
				cfg.TLSConfig = vl
				cfg.tls = &tls.Config{InsecureSkipVerify: true}
			} else {
				name, err := url.QueryUnescape(value)
				if err != nil {
					return fmt.Errorf("invalid value for TLS config name: %v", err)
				}

				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
					cfg.TLSConfig = name
					cfg.tls = tlsConfig
				} else {
					return errors.New("invalid value / unknown config name: " + name)
				}
			}

		// I/O write Timeout
		case "writeTimeout":
			cfg.WriteTimeout, err = time.ParseDuration(value)
			if err != nil {
				return
			}
		case "maxAllowedPacket":
			cfg.MaxAllowedPacket, err = strconv.Atoi(value)
			if err != nil {
				return
			}
		default:
			// lazy init
			if cfg.Params == nil {
				cfg.Params = make(map[string]string)
			}

			if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
				return
			}
		}
	}

	return
}

// ensureHavePort appends the default MySQL port 3306 to addr whenever
// net.SplitHostPort cannot split it (typically because no port is present).
func ensureHavePort(addr string) string {
	if _, _, err := net.SplitHostPort(addr); err != nil {
		return net.JoinHostPort(addr, "3306")
	}
	return addr
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/errors.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"errors"
	"fmt"
	"log"
	"os"
)

// Various errors the driver might return. Can change between driver versions.
var (
	ErrInvalidConn       = errors.New("invalid connection")
	ErrMalformPkt        = errors.New("malformed packet")
	ErrNoTLS             = errors.New("TLS requested but server does not support TLS")
	ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
	ErrNativePassword    = errors.New("this user requires mysql native password authentication.")
	ErrOldPassword       = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
	ErrUnknownPlugin     = errors.New("this authentication plugin is not supported")
	ErrOldProtocol       = errors.New("MySQL server does not support required protocol 41+")
	ErrPktSync           = errors.New("commands out of sync. You can't run this command now")
	ErrPktSyncMul        = errors.New("commands out of sync. Did you run multiple statements at once?")
	ErrPktTooLarge       = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
	ErrBusyBuffer        = errors.New("busy buffer")

	// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
	// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
	// to trigger a resend.
	// See https://github.com/go-sql-driver/mysql/pull/302
	errBadConnNoWrite = errors.New("bad connection")
)

// errLog receives critical driver messages; by default it writes to stderr
// with the "[mysql] " prefix, date, time and short file name.
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))

// Logger is used to log critical error messages.
type Logger interface {
	Print(v ...interface{})
}

// SetLogger is used to set the logger for critical errors.
// The initial logger is os.Stderr.
// A nil logger is rejected with an error and the current logger is kept.
func SetLogger(logger Logger) error {
	if logger == nil {
		return errors.New("logger is nil")
	}
	errLog = logger
	return nil
}

// MySQLError is an error type which represents a single MySQL error
type MySQLError struct {
	Number  uint16
	Message string
}

// Error formats the error as "Error <Number>: <Message>".
func (me *MySQLError) Error() string {
	return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/fields.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql

import (
	"database/sql"
	"reflect"
)

// typeDatabaseName returns the SQL type name for the column, or "" for an
// unknown type. It presumably backs database/sql's
// ColumnType.DatabaseTypeName; the caller is not visible here.
// BLOB/TEXT, BINARY/CHAR and VARBINARY/VARCHAR are told apart by whether the
// column's charset is the binary collation's.
func (mf *mysqlField) typeDatabaseName() string {
	switch mf.fieldType {
	case fieldTypeBit:
		return "BIT"
	case fieldTypeBLOB:
		if mf.charSet != collations[binaryCollation] {
			return "TEXT"
		}
		return "BLOB"
	case fieldTypeDate:
		return "DATE"
	case fieldTypeDateTime:
		return "DATETIME"
	case fieldTypeDecimal:
		return "DECIMAL"
	case fieldTypeDouble:
		return "DOUBLE"
	case fieldTypeEnum:
		return "ENUM"
	case fieldTypeFloat:
		return "FLOAT"
	case fieldTypeGeometry:
		return "GEOMETRY"
	case fieldTypeInt24:
		return "MEDIUMINT"
	case fieldTypeJSON:
		return "JSON"
	case fieldTypeLong:
		return "INT"
	case fieldTypeLongBLOB:
		if mf.charSet != collations[binaryCollation] {
			return "LONGTEXT"
		}
		return "LONGBLOB"
	case fieldTypeLongLong:
		return "BIGINT"
	case fieldTypeMediumBLOB:
		if mf.charSet != collations[binaryCollation] {
			return "MEDIUMTEXT"
		}
		return "MEDIUMBLOB"
	case fieldTypeNewDate:
		return "DATE"
	case fieldTypeNewDecimal:
		return "DECIMAL"
	case fieldTypeNULL:
		return "NULL"
	case fieldTypeSet:
		return "SET"
	case fieldTypeShort:
		return "SMALLINT"
	case fieldTypeString:
		if mf.charSet == collations[binaryCollation] {
			return "BINARY"
		}
		return "CHAR"
	case fieldTypeTime:
		return "TIME"
	case fieldTypeTimestamp:
		return "TIMESTAMP"
	case fieldTypeTiny:
		return "TINYINT"
	case fieldTypeTinyBLOB:
		if mf.charSet != collations[binaryCollation] {
			return "TINYTEXT"
		}
		return "TINYBLOB"
	case fieldTypeVarChar:
		if mf.charSet == collations[binaryCollation] {
			return "VARBINARY"
		}
		return "VARCHAR"
	case fieldTypeVarString:
		if mf.charSet == collations[binaryCollation] {
			return "VARBINARY"
		}
		return "VARCHAR"
	case fieldTypeYear:
		return "YEAR"
	default:
		return ""
	}
}

// Precomputed reflect.Types returned by scanType.
var (
	scanTypeFloat32   = reflect.TypeOf(float32(0))
	scanTypeFloat64   = reflect.TypeOf(float64(0))
	scanTypeInt8      = reflect.TypeOf(int8(0))
	scanTypeInt16     = reflect.TypeOf(int16(0))
	scanTypeInt32     = reflect.TypeOf(int32(0))
	scanTypeInt64     = reflect.TypeOf(int64(0))
	scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
	scanTypeNullInt   = reflect.TypeOf(sql.NullInt64{})
	scanTypeNullTime  = reflect.TypeOf(NullTime{})
	scanTypeUint8     = reflect.TypeOf(uint8(0))
	scanTypeUint16    = reflect.TypeOf(uint16(0))
	scanTypeUint32    = reflect.TypeOf(uint32(0))
	scanTypeUint64    = reflect.TypeOf(uint64(0))
	scanTypeRawBytes  = reflect.TypeOf(sql.RawBytes{})
	scanTypeUnknown   = reflect.TypeOf(new(interface{})) // note: new(interface{}) is *interface{}, a pointer type
)

// mysqlField holds the metadata of one result-set column.
type mysqlField struct {
	tableName string
	name      string
	length    uint32
	flags     fieldFlag
	fieldType fieldType
	decimals  byte
	charSet   uint8
}

// scanType returns the Go type that values of this column scan into
// (presumably reported through database/sql's ColumnType.ScanType).
// NOT NULL integer and float columns map to fixed-size Go types, unsigned
// when flagUnsigned is set. Nullable ones map to sql.NullInt64 /
// sql.NullFloat64. Textual and binary types map to sql.RawBytes.
func (mf *mysqlField) scanType() reflect.Type {
	switch mf.fieldType {
	case fieldTypeTiny:
		if mf.flags&flagNotNULL != 0 {
			if mf.flags&flagUnsigned != 0 {
				return scanTypeUint8
			}
			return scanTypeInt8
		}
		return scanTypeNullInt

	case fieldTypeShort, fieldTypeYear:
		if mf.flags&flagNotNULL != 0 {
			if mf.flags&flagUnsigned != 0 {
				return scanTypeUint16
			}
			return scanTypeInt16
		}
		return scanTypeNullInt

	case fieldTypeInt24, fieldTypeLong:
		if mf.flags&flagNotNULL != 0 {
			if mf.flags&flagUnsigned != 0 {
				return scanTypeUint32
			}
			return scanTypeInt32
		}
		return scanTypeNullInt

	case fieldTypeLongLong:
		if mf.flags&flagNotNULL != 0 {
			if mf.flags&flagUnsigned != 0 {
				return scanTypeUint64
			}
			return scanTypeInt64
		}
		return scanTypeNullInt

	case fieldTypeFloat:
		if mf.flags&flagNotNULL != 0 {
			return scanTypeFloat32
		}
		return scanTypeNullFloat

	case fieldTypeDouble:
		if mf.flags&flagNotNULL != 0 {
			return scanTypeFloat64
		}
		return scanTypeNullFloat

	case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
		fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
		fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
		fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
		fieldTypeTime:
		return scanTypeRawBytes

	case fieldTypeDate, fieldTypeNewDate,
		fieldTypeTimestamp, fieldTypeDateTime:
		// NullTime is always returned for more consistent behavior as it can
		// handle both cases of parseTime regardless if the field is nullable.
		return scanTypeNullTime

	default:
		return scanTypeUnknown
	}
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/infile.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"fmt"
	"io"
	"os"
	"strings"
	"sync"
)

// Registries for LOAD DATA LOCAL INFILE. Each map is created lazily by its
// Register* function and guarded by its own RWMutex.
var (
	fileRegister       map[string]bool
	fileRegisterLock   sync.RWMutex
	readerRegister     map[string]func() io.Reader
	readerRegisterLock sync.RWMutex
)

// RegisterLocalFile adds the given file to the file whitelist,
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
// Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true'
// Surrounding double quotes are trimmed from filePath before it is stored.
//
//  filePath := "/home/gopher/data.csv"
//  mysql.RegisterLocalFile(filePath)
//  err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
//  if err != nil {
//  ...
//
func RegisterLocalFile(filePath string) {
	fileRegisterLock.Lock()
	// lazy map init
	if fileRegister == nil {
		fileRegister = make(map[string]bool)
	}

	fileRegister[strings.Trim(filePath, `"`)] = true
	fileRegisterLock.Unlock()
}

// DeregisterLocalFile removes the given filepath from the whitelist.
func DeregisterLocalFile(filePath string) {
	fileRegisterLock.Lock()
	delete(fileRegister, strings.Trim(filePath, `"`))
	fileRegisterLock.Unlock()
}

// RegisterReaderHandler registers a handler function which is used
// to receive a io.Reader.
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
// If the handler returns a io.ReadCloser Close() is called when the
// request is finished.
//
//  mysql.RegisterReaderHandler("data", func() io.Reader {
//  	var csvReader io.Reader // Some Reader that returns CSV data
//  	...
// Open Reader here // return csvReader // }) // err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") // if err != nil { // ... // func RegisterReaderHandler(name string, handler func() io.Reader) { readerRegisterLock.Lock() // lazy map init if readerRegister == nil { readerRegister = make(map[string]func() io.Reader) } readerRegister[name] = handler readerRegisterLock.Unlock() } // DeregisterReaderHandler removes the ReaderHandler function with // the given name from the registry. func DeregisterReaderHandler(name string) { readerRegisterLock.Lock() delete(readerRegister, name) readerRegisterLock.Unlock() } func deferredClose(err *error, closer io.Closer) { closeErr := closer.Close() if *err == nil { *err = closeErr } } func (mc *mysqlConn) handleInFileRequest(name string) (err error) { var rdr io.Reader var data []byte packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP if mc.maxWriteSize < packetSize { packetSize = mc.maxWriteSize } if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader // The server might return an an absolute path. See issue #355. 
name = name[idx+8:] readerRegisterLock.RLock() handler, inMap := readerRegister[name] readerRegisterLock.RUnlock() if inMap { rdr = handler() if rdr != nil { if cl, ok := rdr.(io.Closer); ok { defer deferredClose(&err, cl) } } else { err = fmt.Errorf("Reader '%s' is ", name) } } else { err = fmt.Errorf("Reader '%s' is not registered", name) } } else { // File name = strings.Trim(name, `"`) fileRegisterLock.RLock() fr := fileRegister[name] fileRegisterLock.RUnlock() if mc.cfg.AllowAllFiles || fr { var file *os.File var fi os.FileInfo if file, err = os.Open(name); err == nil { defer deferredClose(&err, file) // get file size if fi, err = file.Stat(); err == nil { rdr = file if fileSize := int(fi.Size()); fileSize < packetSize { packetSize = fileSize } } } } else { err = fmt.Errorf("local file '%s' is not registered", name) } } // send content packets // if packetSize == 0, the Reader contains no data if err == nil && packetSize > 0 { data := make([]byte, 4+packetSize) var n int for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { return ioErr } } } if err == io.EOF { err = nil } } // send empty packet (termination) if data == nil { data = make([]byte, 4) } if ioErr := mc.writePacket(data[:4]); ioErr != nil { return ioErr } // read OK packet if err == nil { _, err = mc.readResultOK() return err } mc.readPacket() return err } ================================================ FILE: vendor/github.com/go-sql-driver/mysql/packets.go ================================================ // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. 
package mysql

import (
	"bytes"
	"crypto/tls"
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"time"
)

// Packets documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html

// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
	var prevData []byte
	for {
		// read packet header
		data, err := mc.buf.readNext(4)
		if err != nil {
			if cerr := mc.canceled.Value(); cerr != nil {
				return nil, cerr
			}
			errLog.Print(err)
			mc.Close()
			return nil, ErrInvalidConn
		}

		// packet length [24 bit]
		pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)

		// check packet sync [8 bit]
		if data[3] != mc.sequence {
			if data[3] > mc.sequence {
				return nil, ErrPktSyncMul
			}
			return nil, ErrPktSync
		}
		mc.sequence++

		// packets with length 0 terminate a previous packet which is a
		// multiple of (2^24)−1 bytes long
		if pktLen == 0 {
			// there was no previous packet
			if prevData == nil {
				errLog.Print(ErrMalformPkt)
				mc.Close()
				return nil, ErrInvalidConn
			}

			return prevData, nil
		}

		// read packet body [pktLen bytes]
		data, err = mc.buf.readNext(pktLen)
		if err != nil {
			if cerr := mc.canceled.Value(); cerr != nil {
				return nil, cerr
			}
			errLog.Print(err)
			mc.Close()
			return nil, ErrInvalidConn
		}

		// return data if this was the last packet
		if pktLen < maxPacketSize {
			// zero allocations for non-split packets
			if prevData == nil {
				return data, nil
			}

			return append(prevData, data...), nil
		}

		prevData = append(prevData, data...)
	}
}

// Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error {
	pktLen := len(data) - 4

	if pktLen > mc.maxAllowedPacket {
		return ErrPktTooLarge
	}

	for {
		var size int
		if pktLen >= maxPacketSize {
			data[0] = 0xff
			data[1] = 0xff
			data[2] = 0xff
			size = maxPacketSize
		} else {
			data[0] = byte(pktLen)
			data[1] = byte(pktLen >> 8)
			data[2] = byte(pktLen >> 16)
			size = pktLen
		}
		data[3] = mc.sequence

		// Write packet
		if mc.writeTimeout > 0 {
			if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
				return err
			}
		}

		n, err := mc.netConn.Write(data[:4+size])
		if err == nil && n == 4+size {
			mc.sequence++
			if size != maxPacketSize {
				return nil
			}
			pktLen -= size
			data = data[size:]
			continue
		}

		// Handle error
		if err == nil { // n != len(data)
			mc.cleanup()
			errLog.Print(ErrMalformPkt)
		} else {
			if cerr := mc.canceled.Value(); cerr != nil {
				return cerr
			}
			if n == 0 && pktLen == len(data)-4 {
				// only for the first loop iteration when nothing was written yet
				return errBadConnNoWrite
			}
			mc.cleanup()
			errLog.Print(err)
		}
		return ErrInvalidConn
	}
}

/******************************************************************************
*                           Initialisation Process                            *
******************************************************************************/

// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
	data, err := mc.readPacket()
	if err != nil {
		// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
		// in connection initialization we don't risk retrying non-idempotent actions.
		if err == ErrInvalidConn {
			return nil, driver.ErrBadConn
		}
		return nil, err
	}

	if data[0] == iERR {
		return nil, mc.handleErrorPacket(data)
	}

	// protocol version [1 byte]
	if data[0] < minProtocolVersion {
		return nil, fmt.Errorf(
			"unsupported protocol version %d. Version %d or higher is required",
			data[0],
			minProtocolVersion,
		)
	}

	// server version [null terminated string]
	// connection id [4 bytes]
	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4

	// first part of the password cipher [8 bytes]
	cipher := data[pos : pos+8]

	// (filler) always 0x00 [1 byte]
	pos += 8 + 1

	// capability flags (lower 2 bytes) [2 bytes]
	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
	if mc.flags&clientProtocol41 == 0 {
		return nil, ErrOldProtocol
	}
	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
		return nil, ErrNoTLS
	}
	pos += 2

	if len(data) > pos {
		// character set [1 byte]
		// status flags [2 bytes]
		// capability flags (upper 2 bytes) [2 bytes]
		// length of auth-plugin-data [1 byte]
		// reserved (all [00]) [10 bytes]
		pos += 1 + 2 + 2 + 1 + 10

		// second part of the password cipher [minimum 13 bytes],
		// where len=MAX(13, length of auth-plugin-data - 8)
		//
		// The web documentation is ambiguous about the length. However,
		// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
		// the 13th byte is "\0 byte, terminating the second part of
		// a scramble". So the second part of the password cipher is
		// a NULL terminated string that's at least 13 bytes with the
		// last byte being NULL.
		//
		// The official Python library uses the fixed length 12
		// which seems to work but technically could have a hidden bug.
		cipher = append(cipher, data[pos:pos+12]...)

		// TODO: Verify string termination
		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
		// \NUL otherwise
		//
		//if data[len(data)-1] == 0 {
		//	return
		//}
		//return ErrMalformPkt

		// make a memory safe copy of the cipher slice
		var b [20]byte
		copy(b[:], cipher)
		return b[:], nil
	}

	// make a memory safe copy of the cipher slice
	var b [8]byte
	copy(b[:], cipher)
	return b[:], nil
}

// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
	// Adjust client flags based on server support
	clientFlags := clientProtocol41 |
		clientSecureConn |
		clientLongPassword |
		clientTransactions |
		clientLocalFiles |
		clientPluginAuth |
		clientMultiResults |
		mc.flags&clientLongFlag

	if mc.cfg.ClientFoundRows {
		clientFlags |= clientFoundRows
	}

	// To enable TLS / SSL
	if mc.cfg.tls != nil {
		clientFlags |= clientSSL
	}

	if mc.cfg.MultiStatements {
		clientFlags |= clientMultiStatements
	}

	// User Password
	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))

	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1

	// To specify a db name
	if n := len(mc.cfg.DBName); n > 0 {
		clientFlags |= clientConnectWithDB
		pktLen += n + 1
	}

	// Calculate packet length and get buffer with that size
	data := mc.buf.takeSmallBuffer(pktLen + 4)
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// ClientFlags [32 bit]
	data[4] = byte(clientFlags)
	data[5] = byte(clientFlags >> 8)
	data[6] = byte(clientFlags >> 16)
	data[7] = byte(clientFlags >> 24)

	// MaxPacketSize [32 bit] (none)
	data[8] = 0x00
	data[9] = 0x00
	data[10] = 0x00
	data[11] = 0x00

	// Charset [1 byte]
	var found bool
	data[12], found = collations[mc.cfg.Collation]
	if !found {
		// Note possibility for false negatives:
		// could be triggered although the collation is valid if the
		// collations map does not contain entries the server supports.
		return errors.New("unknown collation")
	}

	// Filler [23 bytes] (all 0x00)
	// This must be zeroed BEFORE the SSL request below: that packet ends with
	// these bytes, and the buffer is reused, so it may still hold stale data.
	pos := 13
	for ; pos < 13+23; pos++ {
		data[pos] = 0
	}

	// SSL Connection Request Packet
	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
	if mc.cfg.tls != nil {
		// Send TLS / SSL request packet
		if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
			return err
		}

		// Switch to TLS
		tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
		if err := tlsConn.Handshake(); err != nil {
			return err
		}
		mc.netConn = tlsConn
		mc.buf.nc = tlsConn
	}

	// User [null terminated string]
	if len(mc.cfg.User) > 0 {
		pos += copy(data[pos:], mc.cfg.User)
	}
	data[pos] = 0x00
	pos++

	// ScrambleBuffer [length encoded integer]
	data[pos] = byte(len(scrambleBuff))
	pos += 1 + copy(data[pos+1:], scrambleBuff)

	// Databasename [null terminated string]
	if len(mc.cfg.DBName) > 0 {
		pos += copy(data[pos:], mc.cfg.DBName)
		data[pos] = 0x00
		pos++
	}

	// Assume native client during response
	pos += copy(data[pos:], "mysql_native_password")
	data[pos] = 0x00

	// Send Auth packet
	return mc.writePacket(data)
}

// Client old authentication packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
	// User password
	// https://dev.mysql.com/doc/internals/en/old-password-authentication.html
	// Old password authentication only need and will need 8-byte challenge.
	scrambleBuff := scrambleOldPassword(cipher[:8], []byte(mc.cfg.Passwd))

	// Calculate the packet length and add a tailing 0
	pktLen := len(scrambleBuff) + 1
	data := mc.buf.takeSmallBuffer(4 + pktLen)
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// Add the scrambled password [null terminated string]
	copy(data[4:], scrambleBuff)
	data[4+pktLen-1] = 0x00

	return mc.writePacket(data)
}

// Client clear text authentication packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeClearAuthPacket() error {
	// Calculate the packet length and add a tailing 0
	pktLen := len(mc.cfg.Passwd) + 1
	data := mc.buf.takeSmallBuffer(4 + pktLen)
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// Add the clear password [null terminated string]
	copy(data[4:], mc.cfg.Passwd)
	data[4+pktLen-1] = 0x00

	return mc.writePacket(data)
}

//  Native password authentication method
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
	// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
	// Native password authentication only need and will need 20-byte challenge.
	scrambleBuff := scramblePassword(cipher[0:20], []byte(mc.cfg.Passwd))

	// Calculate the packet length (unlike the old/clear variants, no
	// terminating 0 is appended here)
	pktLen := len(scrambleBuff)
	data := mc.buf.takeSmallBuffer(4 + pktLen)
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// Add the scramble
	copy(data[4:], scrambleBuff)

	return mc.writePacket(data)
}

/******************************************************************************
*                             Command Packets                                 *
******************************************************************************/

func (mc *mysqlConn) writeCommandPacket(command byte) error {
	// Reset Packet Sequence
	mc.sequence = 0

	data := mc.buf.takeSmallBuffer(4 + 1)
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// Add command byte
	data[4] = command

	// Send CMD packet
	return mc.writePacket(data)
}

func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
	// Reset Packet Sequence
	mc.sequence = 0

	pktLen := 1 + len(arg)
	data := mc.buf.takeBuffer(pktLen + 4)
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// Add command byte
	data[4] = command

	// Add arg
	copy(data[5:], arg)

	// Send CMD packet
	return mc.writePacket(data)
}

func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
	// Reset Packet Sequence
	mc.sequence = 0

	data := mc.buf.takeSmallBuffer(4 + 1 + 4)
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// Add command byte
	data[4] = command

	// Add arg [32 bit]
	data[5] = byte(arg)
	data[6] = byte(arg >> 8)
	data[7] = byte(arg >> 16)
	data[8] = byte(arg >> 24)

	// Send CMD packet
	return mc.writePacket(data)
}

/******************************************************************************
*                              Result Packets                                 *
******************************************************************************/

// Returns error if Packet is not an 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() ([]byte, error) {
	data, err := mc.readPacket()
	if err == nil {
		// packet indicator
		switch data[0] {

		case iOK:
			return nil, mc.handleOkPacket(data)

		case iEOF:
			if len(data) > 1 {
				pluginEndIndex := bytes.IndexByte(data, 0x00)
				if pluginEndIndex < 0 {
					// The plugin name must be NUL-terminated; without the
					// terminator the slicing below would panic.
					return nil, ErrMalformPkt
				}
				plugin := string(data[1:pluginEndIndex])
				cipher := data[pluginEndIndex+1:]

				switch plugin {
				case "mysql_old_password":
					// using old_passwords
					return cipher, ErrOldPassword
				case "mysql_clear_password":
					// using clear text password
					return cipher, ErrCleartextPassword
				case "mysql_native_password":
					// using mysql default authentication method
					return cipher, ErrNativePassword
				default:
					return cipher, ErrUnknownPlugin
				}
			}

			// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
			return nil, ErrOldPassword

		default: // Error otherwise
			return nil, mc.handleErrorPacket(data)
		}
	}
	return nil, err
}

// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
	data, err := mc.readPacket()
	if err == nil {
		switch data[0] {

		case iOK:
			return 0, mc.handleOkPacket(data)

		case iERR:
			return 0, mc.handleErrorPacket(data)

		case iLocalInFile:
			return 0, mc.handleInFileRequest(string(data[1:]))
		}

		// column count
		num, _, n := readLengthEncodedInteger(data)
		if n-len(data) == 0 {
			return int(num), nil
		}

		return 0, ErrMalformPkt
	}
	return 0, err
}

// Error Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
	if data[0] != iERR {
		return ErrMalformPkt
	}

	// 0xff [1 byte]

	// Error Number [16 bit uint]
	errno := binary.LittleEndian.Uint16(data[1:3])

	// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
	// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
	if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
		// Oops; we are connected to a read-only connection, and won't be able
		// to issue any write statements. Since RejectReadOnly is configured,
		// we throw away this connection hoping this one would have write
		// permission. This is specifically for a possible race condition
		// during failover (e.g. on AWS Aurora). See README.md for more.
		//
		// We explicitly close the connection before returning
		// driver.ErrBadConn to ensure that `database/sql` purges this
		// connection and initiates a new one for next statement next time.
		mc.Close()
		return driver.ErrBadConn
	}

	pos := 3

	// SQL State [optional: # + 5bytes string]
	if data[3] == 0x23 {
		//sqlstate := string(data[4 : 4+5])
		pos = 9
	}

	// Error Message [string]
	return &MySQLError{
		Number:  errno,
		Message: string(data[pos:]),
	}
}

func readStatus(b []byte) statusFlag {
	return statusFlag(b[0]) | statusFlag(b[1])<<8
}

// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
	var n, m int

	// 0x00 [1 byte]

	// Affected rows [Length Coded Binary]
	mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])

	// Insert id [Length Coded Binary]
	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])

	// server_status [2 bytes]
	mc.status = readStatus(data[1+n+m : 1+n+m+2])
	if mc.status&statusMoreResultsExists != 0 {
		return nil
	}

	// warning count [2 bytes]

	return nil
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
	columns := make([]mysqlField, count)

	for i := 0; ; i++ {
		data, err := mc.readPacket()
		if err != nil {
			return nil, err
		}

		// EOF Packet
		if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
			if i == count {
				return columns, nil
			}
			// Report how many definitions were actually received;
			// len(columns) always equals count and would hide the mismatch.
			return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, i)
		}

		// More definitions than announced would index past the end of columns.
		if i >= count {
			return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, i+1)
		}

		// Catalog
		pos, err := skipLengthEncodedString(data)
		if err != nil {
			return nil, err
		}

		// Database [len coded string]
		n, err := skipLengthEncodedString(data[pos:])
		if err != nil {
			return nil, err
		}
		pos += n

		// Table [len coded string]
		if mc.cfg.ColumnsWithAlias {
			tableName, _, n, err := readLengthEncodedString(data[pos:])
			if err != nil {
				return nil, err
			}
			pos += n
			columns[i].tableName = string(tableName)
		} else {
			n, err = skipLengthEncodedString(data[pos:])
			if err != nil {
				return nil, err
			}
			pos += n
		}

		// Original table [len coded string]
		n, err = skipLengthEncodedString(data[pos:])
		if err != nil {
			return nil, err
		}
		pos += n

		// Name [len coded string]
		name, _, n, err := readLengthEncodedString(data[pos:])
		if err != nil {
			return nil, err
		}
		columns[i].name = string(name)
		pos += n

		// Original name [len coded string]
		n, err = skipLengthEncodedString(data[pos:])
		if err != nil {
			return nil, err
		}
		pos += n

		// Filler [uint8]
		pos++

		// Charset [charset, collation uint8]
		columns[i].charSet = data[pos]
		pos += 2

		// Length [uint32]
		columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
		pos += 4

		// Field type [uint8]
		columns[i].fieldType = fieldType(data[pos])
		pos++

		// Flags [uint16]
		columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
		pos += 2

		// Decimals [uint8]
		columns[i].decimals = data[pos]
		//pos++

		// Default value [len coded binary]
		//if pos < len(data) {
		//	defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
		//}
	}
}

// readRow reads the next row of a text-protocol result set into dest.
// It returns io.EOF once the EOF packet of the current result set is read.
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
func (rows *textRows) readRow(dest []driver.Value) error {
	mc := rows.mc

	if rows.rs.done {
		return io.EOF
	}

	data, err := mc.readPacket()
	if err != nil {
		return err
	}

	// EOF Packet
	if data[0] == iEOF && len(data) == 5 {
		// server_status [2 bytes]
		rows.mc.status = readStatus(data[3:])
		rows.rs.done = true
		if !rows.HasNextResultSet() {
			rows.mc = nil
		}
		return io.EOF
	}
	if data[0] == iERR {
		rows.mc = nil
		return mc.handleErrorPacket(data)
	}

	// RowSet Packet
	var n int
	var isNull bool
	pos := 0

	for i := range dest {
		// Read bytes and convert to string
		dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
		pos += n
		if err == nil {
			if !isNull {
				if !mc.parseTime {
					continue
				} else {
					switch rows.rs.columns[i].fieldType {
					case fieldTypeTimestamp, fieldTypeDateTime,
						fieldTypeDate, fieldTypeNewDate:
						dest[i], err = parseDateTime(
							string(dest[i].([]byte)),
							mc.cfg.Loc,
						)
						if err == nil {
							continue
						}
					default:
						continue
					}
				}

			} else {
				dest[i] = nil
				continue
			}
		}
		return err // err != nil
	}

	return nil
}

// readUntilEOF reads and discards packets until an EOF packet or an error
// packet appears. When the EOF packet carries a status, it is stored in
// mc.status.
func (mc *mysqlConn) readUntilEOF() error {
	for {
		data, err := mc.readPacket()
		if err != nil {
			return err
		}

		switch data[0] {
		case iERR:
			return mc.handleErrorPacket(data)
		case iEOF:
			if len(data) == 5 {
				mc.status = readStatus(data[3:])
			}
			return nil
		}
	}
}

/******************************************************************************
*                           Prepared Statements                               *
******************************************************************************/

// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
	data, err := stmt.mc.readPacket()
	if err == nil {
		// packet indicator [1 byte]
		if data[0] != iOK {
			return 0, stmt.mc.handleErrorPacket(data)
		}

		// statement id [4 bytes]
		stmt.id = binary.LittleEndian.Uint32(data[1:5])

		// Column count [16 bit uint]
		columnCount := binary.LittleEndian.Uint16(data[5:7])

		// Param count [16 bit uint]
		stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))

		// Reserved [8 bit]

		// Warning count [16 bit uint]

		return columnCount, nil
	}
	return 0, err
}

// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
	maxLen := stmt.mc.maxAllowedPacket - 1
	pktLen := maxLen

	// After the header (bytes 0-3) follows before the data:
	// 1 byte command
	// 4 bytes stmtID
	// 2 bytes paramID
	const dataOffset = 1 + 4 + 2

	// Can not use the write buffer since
	// a) the buffer is too small
	// b) it is in use
	data := make([]byte, 4+1+4+2+len(arg))

	copy(data[4+dataOffset:], arg)

	for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
		if dataOffset+argLen < maxLen {
			pktLen = dataOffset + argLen
		}

		stmt.mc.sequence = 0
		// Add command byte [1 byte]
		data[4] = comStmtSendLongData

		// Add stmtID [32 bit]
		data[5] = byte(stmt.id)
		data[6] = byte(stmt.id >> 8)
		data[7] = byte(stmt.id >> 16)
		data[8] = byte(stmt.id >> 24)

		// Add paramID [16 bit]
		data[9] = byte(paramID)
		data[10] = byte(paramID >> 8)

		// Send CMD packet
		err := stmt.mc.writePacket(data[:4+pktLen])
		if err == nil {
			// Slide the window so the next chunk's header overwrites the
			// tail of the chunk that was just sent.
			data = data[pktLen-dataOffset:]
			continue
		}
		return err

	}

	// Reset Packet Sequence
	stmt.mc.sequence = 0
	return nil
}

// Execute Prepared Statement
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
	if len(args) != stmt.paramCount {
		return fmt.Errorf(
			"argument count mismatch (got: %d; has: %d)",
			len(args),
			stmt.paramCount,
		)
	}

	const minPktLen = 4 + 1 + 4 + 1 + 4
	mc := stmt.mc

	// Determine threshold dynamically to avoid packet size shortage.
	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
	if longDataSize < 64 {
		longDataSize = 64
	}

	// Reset packet-sequence
	mc.sequence = 0

	var data []byte

	if len(args) == 0 {
		data = mc.buf.takeBuffer(minPktLen)
	} else {
		data = mc.buf.takeCompleteBuffer()
	}
	if data == nil {
		// can not take the buffer. Something must be wrong with the connection
		errLog.Print(ErrBusyBuffer)
		return errBadConnNoWrite
	}

	// command [1 byte]
	data[4] = comStmtExecute

	// statement_id [4 bytes]
	data[5] = byte(stmt.id)
	data[6] = byte(stmt.id >> 8)
	data[7] = byte(stmt.id >> 16)
	data[8] = byte(stmt.id >> 24)

	// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
	data[9] = 0x00

	// iteration_count (uint32(1)) [4 bytes]
	data[10] = 0x01
	data[11] = 0x00
	data[12] = 0x00
	data[13] = 0x00

	if len(args) > 0 {
		pos := minPktLen

		var nullMask []byte
		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
			// buffer has to be extended but we don't know by how much so
			// we depend on append after all data with known sizes fit.
			// We stop at that because we deal with a lot of columns here
			// which makes the required allocation size hard to guess.
			tmp := make([]byte, pos+maskLen+typesLen)
			copy(tmp[:pos], data[:pos])
			data = tmp
			nullMask = data[pos : pos+maskLen]
			pos += maskLen
		} else {
			nullMask = data[pos : pos+maskLen]
			for i := 0; i < maskLen; i++ {
				nullMask[i] = 0
			}
			pos += maskLen
		}

		// newParameterBoundFlag 1 [1 byte]
		data[pos] = 0x01
		pos++

		// type of each parameter [len(args)*2 bytes]
		paramTypes := data[pos:]
		pos += len(args) * 2

		// value of each parameter [n bytes]
		paramValues := data[pos:pos]
		valuesCap := cap(paramValues)

		for i, arg := range args {
			// build NULL-bitmap
			if arg == nil {
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00
				continue
			}

			// cache types and values
			switch v := arg.(type) {
			case int64:
				paramTypes[i+i] = byte(fieldTypeLongLong)
				paramTypes[i+i+1] = 0x00

				if cap(paramValues)-len(paramValues)-8 >= 0 {
					paramValues = paramValues[:len(paramValues)+8]
					binary.LittleEndian.PutUint64(
						paramValues[len(paramValues)-8:],
						uint64(v),
					)
				} else {
					paramValues = append(paramValues,
						uint64ToBytes(uint64(v))...,
					)
				}

			case float64:
				paramTypes[i+i] = byte(fieldTypeDouble)
				paramTypes[i+i+1] = 0x00

				if cap(paramValues)-len(paramValues)-8 >= 0 {
					paramValues = paramValues[:len(paramValues)+8]
					binary.LittleEndian.PutUint64(
						paramValues[len(paramValues)-8:],
						math.Float64bits(v),
					)
				} else {
					paramValues = append(paramValues,
						uint64ToBytes(math.Float64bits(v))...,
					)
				}

			case bool:
				paramTypes[i+i] = byte(fieldTypeTiny)
				paramTypes[i+i+1] = 0x00

				if v {
					paramValues = append(paramValues, 0x01)
				} else {
					paramValues = append(paramValues, 0x00)
				}

			case []byte:
				// Common case (non-nil value) first
				if v != nil {
					paramTypes[i+i] = byte(fieldTypeString)
					paramTypes[i+i+1] = 0x00

					if len(v) < longDataSize {
						paramValues = appendLengthEncodedInteger(paramValues,
							uint64(len(v)),
						)
						paramValues = append(paramValues, v...)
					} else {
						if err := stmt.writeCommandLongData(i, v); err != nil {
							return err
						}
					}
					continue
				}

				// Handle []byte(nil) as a NULL value
				nullMask[i/8] |= 1 << (uint(i) & 7)
				paramTypes[i+i] = byte(fieldTypeNULL)
				paramTypes[i+i+1] = 0x00

			case string:
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				if len(v) < longDataSize {
					paramValues = appendLengthEncodedInteger(paramValues,
						uint64(len(v)),
					)
					paramValues = append(paramValues, v...)
				} else {
					if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
						return err
					}
				}

			case time.Time:
				paramTypes[i+i] = byte(fieldTypeString)
				paramTypes[i+i+1] = 0x00

				var a [64]byte
				var b = a[:0]

				if v.IsZero() {
					b = append(b, "0000-00-00"...)
				} else {
					b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
				}

				paramValues = appendLengthEncodedInteger(paramValues,
					uint64(len(b)),
				)
				paramValues = append(paramValues, b...)

			default:
				return fmt.Errorf("can not convert type: %T", arg)
			}
		}

		// Check if param values exceeded the available buffer
		// In that case we must build the data packet with the new values buffer
		if valuesCap != cap(paramValues) {
			data = append(data[:pos], paramValues...)
			mc.buf.buf = data
		}

		pos += len(paramValues)
		data = data[:pos]
	}

	return mc.writePacket(data)
}

// discardResults reads and drops every remaining result set while the
// server status reports that more results exist.
func (mc *mysqlConn) discardResults() error {
	for mc.status&statusMoreResultsExists != 0 {
		resLen, err := mc.readResultSetHeaderPacket()
		if err != nil {
			return err
		}
		if resLen > 0 {
			// columns
			if err := mc.readUntilEOF(); err != nil {
				return err
			}
			// rows
			if err := mc.readUntilEOF(); err != nil {
				return err
			}
		}
	}
	return nil
}

// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
func (rows *binaryRows) readRow(dest []driver.Value) error {
	data, err := rows.mc.readPacket()
	if err != nil {
		return err
	}

	// packet indicator [1 byte]
	if data[0] != iOK {
		// EOF Packet
		if data[0] == iEOF && len(data) == 5 {
			rows.mc.status = readStatus(data[3:])
			rows.rs.done = true
			if !rows.HasNextResultSet() {
				rows.mc = nil
			}
			return io.EOF
		}
		mc := rows.mc
		rows.mc = nil

		// Error otherwise
		return mc.handleErrorPacket(data)
	}

	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
	pos := 1 + (len(dest)+7+2)>>3
	nullMask := data[1:pos]

	for i := range dest {
		// Field is NULL
		// (byte >> bit-pos) % 2 == 1
		if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
			dest[i] = nil
			continue
		}

		// Convert to byte-coded string
		switch rows.rs.columns[i].fieldType {
		case fieldTypeNULL:
			dest[i] = nil
			continue

		// Numeric Types
		case fieldTypeTiny:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(data[pos])
			} else {
				dest[i] = int64(int8(data[pos]))
			}
			pos++
			continue

		case fieldTypeShort, fieldTypeYear:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
			} else {
				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
			}
			pos += 2
			continue

		case fieldTypeInt24, fieldTypeLong:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
			} else {
				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
			}
			pos += 4
			continue

		case fieldTypeLongLong:
			if rows.rs.columns[i].flags&flagUnsigned != 0 {
				val := binary.LittleEndian.Uint64(data[pos : pos+8])
				if val > math.MaxInt64 {
					dest[i] = uint64ToString(val)
				} else {
					dest[i] = int64(val)
				}
			} else {
				dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
			}
			pos += 8
			continue

		case fieldTypeFloat:
			dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
			pos += 4
			continue

		case fieldTypeDouble:
			dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
			pos += 8
			continue

		// Length coded Binary Strings
		case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
			fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
			fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
			fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
			var isNull bool
			var n int
			dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
			pos += n
			if err == nil {
				if !isNull {
					continue
				} else {
					dest[i] = nil
					continue
				}
			}
			return err

		case
			fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
			fieldTypeTime,                         // Time [-][H]HH:MM:SS[.fractal]
			fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]

			num, isNull, n := readLengthEncodedInteger(data[pos:])
			pos += n

			switch {
			case isNull:
				dest[i] = nil
				continue
			case rows.rs.columns[i].fieldType == fieldTypeTime:
				// database/sql does not support an equivalent to TIME, return a string
				var dstlen uint8
				switch decimals := rows.rs.columns[i].decimals; decimals {
				case 0x00, 0x1f:
					dstlen = 8
				case 1, 2, 3, 4, 5, 6:
					dstlen = 8 + 1 + decimals
				default:
					return fmt.Errorf(
						"protocol error, illegal decimals value %d",
						rows.rs.columns[i].decimals,
					)
				}
				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
			case rows.mc.parseTime:
				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
			default:
				var dstlen uint8
				if rows.rs.columns[i].fieldType == fieldTypeDate {
					dstlen = 10
				} else {
					switch decimals := rows.rs.columns[i].decimals; decimals {
					case 0x00, 0x1f:
						dstlen = 19
					case 1, 2, 3, 4, 5, 6:
						dstlen = 19 + 1 + decimals
					default:
						return fmt.Errorf(
							"protocol error, illegal decimals value %d",
							rows.rs.columns[i].decimals,
						)
					}
				}
				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false)
			}

			if err == nil {
				pos += int(num)
				continue
			} else {
				return err
			}

		// Please report if this happens!
		default:
			return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
		}
	}

	return nil
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/result.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

// mysqlResult holds the affected-row count and last insert id of an executed
// statement. Its methods match the database/sql/driver Result interface.
type mysqlResult struct {
	affectedRows int64
	insertId     int64
}

func (res *mysqlResult) LastInsertId() (int64, error) {
	return res.insertId, nil
}

func (res *mysqlResult) RowsAffected() (int64, error) {
	return res.affectedRows, nil
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/rows.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"database/sql/driver"
	"io"
	"math"
	"reflect"
)

// resultSet holds the column metadata of the result set currently being read.
type resultSet struct {
	columns     []mysqlField
	columnNames []string // lazily filled cache used by (*mysqlRows).Columns
	// done reports that no unread packets of this result set remain on the
	// stream; while it is false, Close and nextResultSet drain the stream
	// with readUntilEOF before moving on.
	done bool
}

// mysqlRows is the state shared by the textRows and binaryRows iterators.
// mc is set to nil once the rows are detached from the connection (by Close,
// or by nextResultSet when no further result set exists).
type mysqlRows struct {
	mc     *mysqlConn
	rs     resultSet
	finish func() // called and then cleared by Close
}

// binaryRows iterates rows returned for prepared statements (see
// mysqlStmt.query).
type binaryRows struct {
	mysqlRows
}

// textRows is the text-protocol counterpart of binaryRows (presumably used
// for non-prepared queries — confirm against the connection code).
type textRows struct {
	mysqlRows
}

// Columns returns the column names of the current result set. When the
// ColumnsWithAlias option is enabled, a name is prefixed with its table name
// ("table.column") whenever the server reported one. The slice is cached in
// rs.columnNames and returned as-is on later calls.
func (rows *mysqlRows) Columns() []string {
	if rows.rs.columnNames != nil {
		return rows.rs.columnNames
	}

	columns := make([]string, len(rows.rs.columns))
	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
		for i := range columns {
			if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
				columns[i] = tableName + "." + rows.rs.columns[i].name
			} else {
				columns[i] = rows.rs.columns[i].name
			}
		}
	} else {
		for i := range columns {
			columns[i] = rows.rs.columns[i].name
		}
	}

	rows.rs.columnNames = columns
	return columns
}

// ColumnTypeDatabaseTypeName returns the database type name of column i as
// reported by mysqlField.typeDatabaseName.
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
	return rows.rs.columns[i].typeDatabaseName()
}

// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
// 	return int64(rows.rs.columns[i].length), true
// }

// ColumnTypeNullable reports column i as nullable unless its NOT NULL flag is
// set; ok is always true.
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
	return rows.rs.columns[i].flags&flagNotNULL == 0, true
}

// ColumnTypePrecisionScale returns precision and scale for DECIMAL, temporal
// and floating-point columns; ok is false for every other column type.
//
// For DECIMAL the reported column length is reduced by 2 when there is a
// fractional part and by 1 otherwise (presumably for the sign and decimal
// point — TODO confirm). For FLOAT/DOUBLE the precision is math.MaxInt64, and
// the scale is also math.MaxInt64 when decimals is 0x1f (which appears to be
// the server's "not fixed" marker).
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
	column := rows.rs.columns[i]
	decimals := int64(column.decimals)

	switch column.fieldType {
	case fieldTypeDecimal, fieldTypeNewDecimal:
		if decimals > 0 {
			return int64(column.length) - 2, decimals, true
		}
		return int64(column.length) - 1, decimals, true
	case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
		return decimals, decimals, true
	case fieldTypeFloat, fieldTypeDouble:
		if decimals == 0x1f {
			return math.MaxInt64, math.MaxInt64, true
		}
		return math.MaxInt64, decimals, true
	}

	return 0, 0, false
}

// ColumnTypeScanType returns the Go type suggested for scanning column i, as
// reported by mysqlField.scanType.
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
	return rows.rs.columns[i].scanType()
}

// Close releases the rows. It first runs and clears the finish callback.
// If the rows are still attached to a connection, it then:
//   - returns any pending connection error;
//   - drains unread packets of the current result set;
//   - calls mc.discardResults to drop further result sets;
//   - detaches from the connection.
//
// NOTE(review): when mc.error() is non-nil, Close returns before clearing
// rows.mc.
func (rows *mysqlRows) Close() (err error) {
	if f := rows.finish; f != nil {
		f()
		rows.finish = nil
	}

	mc := rows.mc
	if mc == nil {
		return nil
	}
	if err := mc.error(); err != nil {
		return err
	}

	// Remove unread packets from stream
	if !rows.rs.done {
		err = mc.readUntilEOF()
	}
	if err == nil {
		if err = mc.discardResults(); err != nil {
			return err
		}
	}

	rows.mc = nil
	return err
}

// HasNextResultSet reports whether the connection status has the
// statusMoreResultsExists bit set. It is always false once the rows are
// detached.
func (rows *mysqlRows) HasNextResultSet() (b bool) {
	if rows.mc == nil {
		return false
	}
	return rows.mc.status&statusMoreResultsExists != 0
}

// nextResultSet advances to the next result set and returns the column count
// from its header packet. It first drains the current result set if needed.
// When no further result set exists, it detaches the rows and returns io.EOF.
func (rows *mysqlRows) nextResultSet() (int, error) {
	if rows.mc == nil {
		return 0, io.EOF
	}
	if err := rows.mc.error(); err != nil {
		return 0, err
	}

	// Remove unread packets from stream
	if !rows.rs.done {
		if err := rows.mc.readUntilEOF(); err != nil {
			return 0, err
		}
		rows.rs.done = true
	}

	if !rows.HasNextResultSet() {
		rows.mc = nil
		return 0, io.EOF
	}
	rows.rs = resultSet{}
	return rows.mc.readResultSetHeaderPacket()
}

// nextNotEmptyResultSet skips result sets that have zero columns. It returns
// the column count of the first result set that has at least one column.
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
	for {
		resLen, err := rows.nextResultSet()
		if err != nil {
			return 0, err
		}

		if resLen > 0 {
			return resLen, nil
		}

		// A column-less result set has no row packets left to drain.
		rows.rs.done = true
	}
}

// NextResultSet advances to the next non-empty result set and reads its
// column definitions.
func (rows *binaryRows) NextResultSet() error {
	resLen, err := rows.nextNotEmptyResultSet()
	if err != nil {
		return err
	}

	rows.rs.columns, err = rows.mc.readColumns(resLen)
	return err
}

// Next reads the next row into dest via readRow. It returns io.EOF once the
// rows are detached, and any pending connection error before reading.
func (rows *binaryRows) Next(dest []driver.Value) error {
	if mc := rows.mc; mc != nil {
		if err := mc.error(); err != nil {
			return err
		}

		// Fetch next row from stream
		return rows.readRow(dest)
	}
	return io.EOF
}

// NextResultSet advances to the next non-empty result set and reads its
// column definitions.
func (rows *textRows) NextResultSet() (err error) {
	resLen, err := rows.nextNotEmptyResultSet()
	if err != nil {
		return err
	}

	rows.rs.columns, err = rows.mc.readColumns(resLen)
	return err
}

// Next reads the next row into dest via readRow. It returns io.EOF once the
// rows are detached, and any pending connection error before reading.
func (rows *textRows) Next(dest []driver.Value) error {
	if mc := rows.mc; mc != nil {
		if err := mc.error(); err != nil {
			return err
		}

		// Fetch next row from stream
		return rows.readRow(dest)
	}
	return io.EOF
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/statement.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's
database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
	"database/sql/driver"
	"fmt"
	"io"
	"reflect"
	"strconv"
)

// mysqlStmt is a server-side prepared statement: id is the statement id used
// in protocol commands and paramCount the number of placeholders.
type mysqlStmt struct {
	mc         *mysqlConn
	id         uint32
	paramCount int
}

// Close sends comStmtClose for the statement and detaches it from the
// connection. If the statement is already detached or the connection is
// closed, it returns driver.ErrBadConn without sending anything.
func (stmt *mysqlStmt) Close() error {
	if stmt.mc == nil || stmt.mc.closed.IsSet() {
		// driver.Stmt.Close can be called more than once, thus this function
		// has to be idempotent.
		// See also Issue #450 and golang/go#16019.
		//errLog.Print(ErrInvalidConn)
		return driver.ErrBadConn
	}

	err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
	stmt.mc = nil
	return err
}

// NumInput returns the number of placeholders in the statement.
func (stmt *mysqlStmt) NumInput() int {
	return stmt.paramCount
}

// ColumnConverter returns the driver's converter for every argument index.
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
	return converter{}
}

// Exec executes the prepared statement with args and returns the affected row
// count and last insert id. Any result rows are read and discarded.
//
// NOTE(review): unlike Close, Exec does not check stmt.mc for nil, so calling
// it after Close would dereference a nil pointer — presumably database/sql
// never does that; confirm.
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
	if stmt.mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	// Send command
	err := stmt.writeExecutePacket(args)
	if err != nil {
		return nil, stmt.mc.markBadConn(err)
	}

	mc := stmt.mc

	// Reset so the values reported below come from this execution only.
	mc.affectedRows = 0
	mc.insertId = 0

	// Read Result
	resLen, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}

	if resLen > 0 {
		// Columns
		if err = mc.readUntilEOF(); err != nil {
			return nil, err
		}

		// Rows
		if err := mc.readUntilEOF(); err != nil {
			return nil, err
		}
	}

	if err := mc.discardResults(); err != nil {
		return nil, err
	}

	return &mysqlResult{
		affectedRows: int64(mc.affectedRows),
		insertId:     int64(mc.insertId),
	}, nil
}

// Query executes the prepared statement and returns its rows.
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
	return stmt.query(args)
}

// query executes the prepared statement and returns a binaryRows reading from
// the connection.
//
// If the first result has no columns, it is marked done and the rows try to
// advance via NextResultSet; io.EOF from that step is reported as empty rows.
// NOTE(review): rows.mc is not set in that branch, so NextResultSet always
// sees a detached iterator and returns io.EOF — that path therefore always
// yields empty rows. Confirm this is intended.
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
	if stmt.mc.closed.IsSet() {
		errLog.Print(ErrInvalidConn)
		return nil, driver.ErrBadConn
	}
	// Send command
	err := stmt.writeExecutePacket(args)
	if err != nil {
		return nil, stmt.mc.markBadConn(err)
	}

	mc := stmt.mc

	// Read Result
	resLen, err := mc.readResultSetHeaderPacket()
	if err != nil {
		return nil, err
	}

	rows := new(binaryRows)

	if resLen > 0 {
		rows.mc = mc
		rows.rs.columns, err = mc.readColumns(resLen)
	} else {
		rows.rs.done = true

		switch err := rows.NextResultSet(); err {
		case nil, io.EOF:
			return rows, nil
		default:
			return nil, err
		}
	}

	return rows, err
}

// converter is the driver.ValueConverter returned by ColumnConverter.
type converter struct{}

// ConvertValue mirrors the reference/default converter in database/sql/driver
// with _one_ exception: we support uint64 values with their high bit set and
// the default implementation does not. This function should be kept in sync
// with database/sql/driver defaultConverter.ConvertValue() except for that
// deliberate difference.
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
	if driver.IsValue(v) {
		return v, nil
	}

	if vr, ok := v.(driver.Valuer); ok {
		sv, err := callValuerValue(vr)
		if err != nil {
			return nil, err
		}
		if !driver.IsValue(sv) {
			return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
		}
		return sv, nil
	}

	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Ptr:
		// indirect pointers
		if rv.IsNil() {
			return nil, nil
		} else {
			return c.ConvertValue(rv.Elem().Interface())
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int(), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
		return int64(rv.Uint()), nil
	case reflect.Uint64:
		u64 := rv.Uint()
		// Values that do not fit in int64 are sent as their decimal string.
		if u64 >= 1<<63 {
			return strconv.FormatUint(u64, 10), nil
		}
		return int64(u64), nil
	case reflect.Float32, reflect.Float64:
		return rv.Float(), nil
	case reflect.Bool:
		return rv.Bool(), nil
	case reflect.Slice:
		ek := rv.Type().Elem().Kind()
		if ek == reflect.Uint8 {
			return rv.Bytes(), nil
		}
		return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
	case reflect.String:
		return rv.String(), nil
	}
	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}

// valuerReflectType is the reflect.Type of the driver.Valuer interface.
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This is an exact copy of the same-named unexported function from the
// database/sql package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
		rv.IsNil() &&
		rv.Type().Elem().Implements(valuerReflectType) {
		return nil, nil
	}
	return vr.Value()
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/transaction.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

// mysqlTx is a transaction bound to a connection; mc is cleared once the
// transaction has been committed or rolled back.
type mysqlTx struct {
	mc *mysqlConn
}

// Commit sends "COMMIT" and detaches the transaction. It returns
// ErrInvalidConn if the transaction is already finished or the connection is
// closed.
func (tx *mysqlTx) Commit() (err error) {
	if tx.mc == nil || tx.mc.closed.IsSet() {
		return ErrInvalidConn
	}
	err = tx.mc.exec("COMMIT")
	tx.mc = nil
	return
}

// Rollback sends "ROLLBACK" and detaches the transaction. It returns
// ErrInvalidConn if the transaction is already finished or the connection is
// closed.
func (tx *mysqlTx) Rollback() (err error) {
	if tx.mc == nil || tx.mc.closed.IsSet() {
		return ErrInvalidConn
	}
	err = tx.mc.exec("ROLLBACK")
	tx.mc = nil
	return
}

================================================
FILE: vendor/github.com/go-sql-driver/mysql/utils.go
================================================
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
// // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. package mysql import ( "crypto/sha1" "crypto/tls" "database/sql/driver" "encoding/binary" "fmt" "io" "strings" "sync" "sync/atomic" "time" ) var ( tlsConfigLock sync.RWMutex tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs ) // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // // Note: The tls.Config provided to needs to be exclusively owned by the driver after registering. // // rootCertPool := x509.NewCertPool() // pem, err := ioutil.ReadFile("/path/ca-cert.pem") // if err != nil { // log.Fatal(err) // } // if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { // log.Fatal("Failed to append PEM.") // } // clientCert := make([]tls.Certificate, 0, 1) // certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") // if err != nil { // log.Fatal(err) // } // clientCert = append(clientCert, certs) // mysql.RegisterTLSConfig("custom", &tls.Config{ // RootCAs: rootCertPool, // Certificates: clientCert, // }) // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { return fmt.Errorf("key '%s' is reserved", key) } tlsConfigLock.Lock() if tlsConfigRegister == nil { tlsConfigRegister = make(map[string]*tls.Config) } tlsConfigRegister[key] = config tlsConfigLock.Unlock() return nil } // DeregisterTLSConfig removes the tls.Config associated with key. 
func DeregisterTLSConfig(key string) {
	tlsConfigLock.Lock()
	if tlsConfigRegister != nil {
		delete(tlsConfigRegister, key)
	}
	tlsConfigLock.Unlock()
}

// getTLSConfigClone returns a copy (made by cloneTLSConfig) of the tls.Config
// registered under key, or nil if nothing is registered under that key.
func getTLSConfigClone(key string) (config *tls.Config) {
	tlsConfigLock.RLock()
	if v, ok := tlsConfigRegister[key]; ok {
		config = cloneTLSConfig(v)
	}
	tlsConfigLock.RUnlock()
	return
}

// readBool parses the boolean spellings "1", "true", "TRUE", "True" and
// "0", "false", "FALSE", "False".
// The 2nd return value indicates whether input was one of them; for any other
// input both results are false.
func readBool(input string) (value bool, valid bool) {
	switch input {
	case "1", "true", "TRUE", "True":
		return true, true
	case "0", "false", "FALSE", "False":
		return false, true
	}

	// Not a valid bool value
	return
}

/******************************************************************************
*                             Authentication                                  *
******************************************************************************/

// scramblePassword computes the authentication token for the 4.1+ method:
// SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password))).
// It returns nil for an empty password. The result is a freshly allocated
// slice; the scramble argument is only read.
func scramblePassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	// hash.Hash documents that Write never returns an error, so the results
	// of crypt.Write are not checked below.

	// stage1Hash = SHA1(password)
	crypt := sha1.New()
	crypt.Write(password)
	stage1 := crypt.Sum(nil)

	// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
	// inner Hash
	crypt.Reset()
	crypt.Write(stage1)
	hash := crypt.Sum(nil)

	// outer Hash
	crypt.Reset()
	crypt.Write(scramble)
	crypt.Write(hash)
	scramble = crypt.Sum(nil)

	// token = scrambleHash XOR stage1Hash
	for i := range scramble {
		scramble[i] ^= stage1[i]
	}
	return scramble
}

// myRnd is the pseudo random number generator used by the insecure pre 4.1
// (old password) method.
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
	seed1, seed2 uint32
}

// myRndMaxVal is the modulus of both generator seeds.
const myRndMaxVal = 0x3FFFFFFF

// newMyRnd returns a generator whose seeds are reduced modulo myRndMaxVal.
func newMyRnd(seed1, seed2 uint32) *myRnd {
	return &myRnd{
		seed1: seed1 % myRndMaxVal,
		seed2: seed2 % myRndMaxVal,
	}
}

// NextByte advances the generator and returns a value in [0, 31).
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
	r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
	r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal

	return byte(uint64(r.seed1) * 31 / myRndMaxVal)
}

// pwHash computes the two-word hash of password used by the insecure pre 4.1
// method. Spaces and tabs in password do not contribute to the hash.
func pwHash(password []byte) (result [2]uint32) {
	var add uint32 = 7
	var tmp uint32

	result[0] = 1345345333
	result[1] = 0x12345671

	for _, c := range password {
		// skip spaces and tabs in password
		if c == ' ' || c == '\t' {
			continue
		}

		tmp = uint32(c)
		result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
		result[1] += (result[1] << 8) ^ result[0]
		add += tmp
	}

	// Clear the sign bit of both words (mask with (1<<31)-1).
	result[0] &= 0x7FFFFFFF
	result[1] &= 0x7FFFFFFF

	return
}

// scrambleOldPassword computes the 8-byte token for the insecure pre 4.1
// method from the first 8 bytes of scramble. It returns nil for an empty
// password.
// NOTE(review): scramble[:8] panics if cap(scramble) < 8.
func scrambleOldPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}

	scramble = scramble[:8]

	hashPw := pwHash(password)
	hashSc := pwHash(scramble)

	r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])

	var out [8]byte
	for i := range out {
		out[i] = r.NextByte() + 64
	}

	mask := r.NextByte()
	for i := range out {
		out[i] ^= mask
	}

	return out[:]
}

/******************************************************************************
*                           Time related utils                                *
******************************************************************************/

// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
//	var nt NullTime
//	err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
//	...
//	if nt.Valid {
//	   // use nt.Time
//	} else {
//	   // NULL value
//	}
//
// This NullTime implementation is not driver-specific
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) { if value == nil { nt.Time, nt.Valid = time.Time{}, false return } switch v := value.(type) { case time.Time: nt.Time, nt.Valid = v, true return case []byte: nt.Time, err = parseDateTime(string(v), time.UTC) nt.Valid = (err == nil) return case string: nt.Time, err = parseDateTime(v, time.UTC) nt.Valid = (err == nil) return } nt.Valid = false return fmt.Errorf("Can't convert %T to time.Time", value) } // Value implements the driver Valuer interface. func (nt NullTime) Value() (driver.Value, error) { if !nt.Valid { return nil, nil } return nt.Time, nil } func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { base := "0000-00-00 00:00:00.0000000" switch len(str) { case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" if str == base[:len(str)] { return } t, err = time.Parse(timeFormat[:len(str)], str) default: err = fmt.Errorf("invalid time string: %s", str) return } // Adjust location if err == nil && loc != time.UTC { y, mo, d := t.Date() h, mi, s := t.Clock() t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil } return } func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { switch num { case 0: return time.Time{}, nil case 4: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day 0, 0, 0, 0, loc, ), nil case 7: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds 0, loc, ), nil case 11: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds loc, ), nil } return nil, fmt.Errorf("invalid DATETIME packet length %d", 
num) } // zeroDateTime is used in formatBinaryDateTime to avoid an allocation // if the DATE or DATETIME has the zero value. // It must never be changed. // The current behavior depends on database/sql copying the result. var zeroDateTime = []byte("0000-00-00 00:00:00.000000") const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { // length expects the deterministic length of the zero value, // negative time and 100+ hours are automatically added if needed if len(src) == 0 { if justTime { return zeroDateTime[11 : 11+length], nil } return zeroDateTime[:length], nil } var dst []byte // return value var pt, p1, p2, p3 byte // current digit pair var zOffs byte // offset of value in zeroDateTime if justTime { switch length { case 8, // time (can be up to 10 when negative and 100+ hours) 10, 11, 12, 13, 14, 15: // time with fractional seconds default: return nil, fmt.Errorf("illegal TIME length %d", length) } switch len(src) { case 8, 12: default: return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) } // +2 to enable negative time and 100+ hours dst = make([]byte, 0, length+2) if src[0] == 1 { dst = append(dst, '-') } if src[1] != 0 { hour := uint16(src[1])*24 + uint16(src[5]) pt = byte(hour / 100) p1 = byte(hour - 100*uint16(pt)) dst = append(dst, digits01[pt]) } else { p1 = src[5] } zOffs = 11 src = src[6:] } else { switch length { case 10, 19, 21, 22, 23, 24, 25, 26: default: t := "DATE" if length > 10 { t += "TIME" } return nil, fmt.Errorf("illegal %s length %d", t, length) } switch len(src) { case 4, 7, 11: default: t := "DATE" if length > 10 { t += "TIME" } return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) } dst = make([]byte, 0, length) // start with the date year 
:= binary.LittleEndian.Uint16(src[:2]) pt = byte(year / 100) p1 = byte(year - 100*uint16(pt)) p2, p3 = src[2], src[3] dst = append(dst, digits10[pt], digits01[pt], digits10[p1], digits01[p1], '-', digits10[p2], digits01[p2], '-', digits10[p3], digits01[p3], ) if length == 10 { return dst, nil } if len(src) == 4 { return append(dst, zeroDateTime[10:length]...), nil } dst = append(dst, ' ') p1 = src[4] // hour src = src[5:] } // p1 is 2-digit hour, src is after hour p2, p3 = src[0], src[1] dst = append(dst, digits10[p1], digits01[p1], ':', digits10[p2], digits01[p2], ':', digits10[p3], digits01[p3], ) if length <= byte(len(dst)) { return dst, nil } src = src[2:] if len(src) == 0 { return append(dst, zeroDateTime[19:zOffs+length]...), nil } microsecs := binary.LittleEndian.Uint32(src[:4]) p1 = byte(microsecs / 10000) microsecs -= 10000 * uint32(p1) p2 = byte(microsecs / 100) microsecs -= 100 * uint32(p2) p3 = byte(microsecs) switch decimals := zOffs + length - 20; decimals { default: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], digits10[p3], digits01[p3], ), nil case 1: return append(dst, '.', digits10[p1], ), nil case 2: return append(dst, '.', digits10[p1], digits01[p1], ), nil case 3: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], ), nil case 4: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], ), nil case 5: return append(dst, '.', digits10[p1], digits01[p1], digits10[p2], digits01[p2], digits10[p3], ), nil } } /****************************************************************************** * Convert from and to bytes * ******************************************************************************/ func uint64ToBytes(n uint64) []byte { return []byte{ byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56), } } func uint64ToString(n uint64) []byte { var a [20]byte i := 20 // U+0030 = 0 // ... 
// U+0039 = 9 var q uint64 for n >= 10 { i-- q = n / 10 a[i] = uint8(n-q*10) + 0x30 n = q } i-- a[i] = uint8(n) + 0x30 return a[i:] } // treats string value as unsigned integer representation func stringToInt(b []byte) int { val := 0 for i := range b { val *= 10 val += int(b[i] - 0x30) } return val } // returns the string read as a bytes slice, wheter the value is NULL, // the number of bytes read and an error, in case the string is longer than // the input slice func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { return b[n:n], isNull, n, nil } n += int(num) // Check data length if len(b) >= n { return b[n-int(num) : n : n], false, n, nil } return nil, false, n, io.EOF } // returns the number of bytes skipped and an error, in case the string is // longer than the input slice func skipLengthEncodedString(b []byte) (int, error) { // Get length num, _, n := readLengthEncodedInteger(b) if num < 1 { return n, nil } n += int(num) // Check data length if len(b) >= n { return n, nil } return n, io.EOF } // returns the number read, whether the value is NULL and the number of bytes read func readLengthEncodedInteger(b []byte) (uint64, bool, int) { // See issue #349 if len(b) == 0 { return 0, true, 1 } switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 // 252: value of following 2 case 0xfc: return uint64(b[1]) | uint64(b[2])<<8, false, 3 // 253: value of following 3 case 0xfd: return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 // 254: value of following 8 case 0xfe: return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | uint64(b[7])<<48 | uint64(b[8])<<56, false, 9 } // 0-250: value of first byte return uint64(b[0]), false, 1 } // encodes a uint64 value and appends it to the given bytes slice func appendLengthEncodedInteger(b []byte, n uint64) []byte { switch { case n <= 250: return append(b, byte(n)) case 
n <= 0xffff: return append(b, 0xfc, byte(n), byte(n>>8)) case n <= 0xffffff: return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) } return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. // If cap(buf) is not enough, reallocate new buffer. func reserveBuffer(buf []byte, appendSize int) []byte { newSize := len(buf) + appendSize if cap(buf) < newSize { // Grow buffer exponentially newBuf := make([]byte, len(buf)*2+appendSize) copy(newBuf, buf) buf = newBuf } return buf[:newSize] } // escapeBytesBackslash escapes []byte with backslashes (\) // This escapes the contents of a string (provided as []byte) by adding backslashes before special // characters, and turning others into specific escape sequences, such as // turning newlines into \n and null bytes into \0. // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 func escapeBytesBackslash(buf, v []byte) []byte { pos := len(buf) buf = reserveBuffer(buf, len(v)*2) for _, c := range v { switch c { case '\x00': buf[pos] = '\\' buf[pos+1] = '0' pos += 2 case '\n': buf[pos] = '\\' buf[pos+1] = 'n' pos += 2 case '\r': buf[pos] = '\\' buf[pos+1] = 'r' pos += 2 case '\x1a': buf[pos] = '\\' buf[pos+1] = 'Z' pos += 2 case '\'': buf[pos] = '\\' buf[pos+1] = '\'' pos += 2 case '"': buf[pos] = '\\' buf[pos+1] = '"' pos += 2 case '\\': buf[pos] = '\\' buf[pos+1] = '\\' pos += 2 default: buf[pos] = c pos++ } } return buf[:pos] } // escapeStringBackslash is similar to escapeBytesBackslash but for string. 
func escapeStringBackslash(buf []byte, v string) []byte { pos := len(buf) buf = reserveBuffer(buf, len(v)*2) for i := 0; i < len(v); i++ { c := v[i] switch c { case '\x00': buf[pos] = '\\' buf[pos+1] = '0' pos += 2 case '\n': buf[pos] = '\\' buf[pos+1] = 'n' pos += 2 case '\r': buf[pos] = '\\' buf[pos+1] = 'r' pos += 2 case '\x1a': buf[pos] = '\\' buf[pos+1] = 'Z' pos += 2 case '\'': buf[pos] = '\\' buf[pos+1] = '\'' pos += 2 case '"': buf[pos] = '\\' buf[pos+1] = '"' pos += 2 case '\\': buf[pos] = '\\' buf[pos+1] = '\\' pos += 2 default: buf[pos] = c pos++ } } return buf[:pos] } // escapeBytesQuotes escapes apostrophes in []byte by doubling them up. // This escapes the contents of a string by doubling up any apostrophes that // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in // effect on the server. // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 func escapeBytesQuotes(buf, v []byte) []byte { pos := len(buf) buf = reserveBuffer(buf, len(v)*2) for _, c := range v { if c == '\'' { buf[pos] = '\'' buf[pos+1] = '\'' pos += 2 } else { buf[pos] = c pos++ } } return buf[:pos] } // escapeStringQuotes is similar to escapeBytesQuotes but for string. func escapeStringQuotes(buf []byte, v string) []byte { pos := len(buf) buf = reserveBuffer(buf, len(v)*2) for i := 0; i < len(v); i++ { c := v[i] if c == '\'' { buf[pos] = '\'' buf[pos+1] = '\'' pos += 2 } else { buf[pos] = c pos++ } } return buf[:pos] } /****************************************************************************** * Sync utils * ******************************************************************************/ // noCopy may be embedded into structs which must not be copied // after the first use. // // See https://github.com/golang/go/issues/8005#issuecomment-190753527 // for details. type noCopy struct{} // Lock is a no-op used by -copylocks checker from `go vet`. 
func (*noCopy) Lock() {} // atomicBool is a wrapper around uint32 for usage as a boolean value with // atomic access. type atomicBool struct { _noCopy noCopy value uint32 } // IsSet returns wether the current boolean value is true func (ab *atomicBool) IsSet() bool { return atomic.LoadUint32(&ab.value) > 0 } // Set sets the value of the bool regardless of the previous value func (ab *atomicBool) Set(value bool) { if value { atomic.StoreUint32(&ab.value, 1) } else { atomic.StoreUint32(&ab.value, 0) } } // TrySet sets the value of the bool and returns wether the value changed func (ab *atomicBool) TrySet(value bool) bool { if value { return atomic.SwapUint32(&ab.value, 1) == 0 } return atomic.SwapUint32(&ab.value, 0) > 0 } // atomicError is a wrapper for atomically accessed error values type atomicError struct { _noCopy noCopy value atomic.Value } // Set sets the error value regardless of the previous value. // The value must not be nil func (ae *atomicError) Set(value error) { ae.value.Store(value) } // Value returns the current error value func (ae *atomicError) Value() error { if v := ae.value.Load(); v != nil { // this will panic if the value doesn't implement the error interface return v.(error) } return nil } ================================================ FILE: vendor/github.com/go-sql-driver/mysql/utils_go17.go ================================================ // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. 
// +build go1.7 // +build !go1.8 package mysql import "crypto/tls" func cloneTLSConfig(c *tls.Config) *tls.Config { return &tls.Config{ Rand: c.Rand, Time: c.Time, Certificates: c.Certificates, NameToCertificate: c.NameToCertificate, GetCertificate: c.GetCertificate, RootCAs: c.RootCAs, NextProtos: c.NextProtos, ServerName: c.ServerName, ClientAuth: c.ClientAuth, ClientCAs: c.ClientCAs, InsecureSkipVerify: c.InsecureSkipVerify, CipherSuites: c.CipherSuites, PreferServerCipherSuites: c.PreferServerCipherSuites, SessionTicketsDisabled: c.SessionTicketsDisabled, SessionTicketKey: c.SessionTicketKey, ClientSessionCache: c.ClientSessionCache, MinVersion: c.MinVersion, MaxVersion: c.MaxVersion, CurvePreferences: c.CurvePreferences, DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, Renegotiation: c.Renegotiation, } } ================================================ FILE: vendor/github.com/go-sql-driver/mysql/utils_go18.go ================================================ // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. 
// +build go1.8 package mysql import ( "crypto/tls" "database/sql" "database/sql/driver" "errors" "fmt" ) func cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { dargs := make([]driver.Value, len(named)) for n, param := range named { if len(param.Name) > 0 { // TODO: support the use of Named Parameters #561 return nil, errors.New("mysql: driver does not support the use of Named Parameters") } dargs[n] = param.Value } return dargs, nil } func mapIsolationLevel(level driver.IsolationLevel) (string, error) { switch sql.IsolationLevel(level) { case sql.LevelRepeatableRead: return "REPEATABLE READ", nil case sql.LevelReadCommitted: return "READ COMMITTED", nil case sql.LevelReadUncommitted: return "READ UNCOMMITTED", nil case sql.LevelSerializable: return "SERIALIZABLE", nil default: return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) } } ================================================ FILE: vendor/github.com/google/wire/AUTHORS ================================================ # This is the official list of Wire authors for copyright purposes. # This file is distinct from the CONTRIBUTORS files. # See the latter for an explanation. # Names should be added to this file as one of # Organization's name # Individual's name # Individual's name # See CONTRIBUTORS for the meaning of multiple email addresses. # Please keep the list sorted. Google LLC ktr Oleg Kovalov Yoichiro Shimizu Zachary Romero ================================================ FILE: vendor/github.com/google/wire/CODE_OF_CONDUCT.md ================================================ # Code of Conduct This project is covered under the [Go Code of Conduct][]. In summary: - Treat everyone with respect and kindness. - Be thoughtful in how you communicate. - Don’t be destructive or inflammatory. - If you encounter an issue, please mail conduct@golang.org. 
[Go Code of Conduct]: https://golang.org/conduct ================================================ FILE: vendor/github.com/google/wire/CONTRIBUTING.md ================================================ # How to Contribute We would love to accept your patches and contributions to this project. Here is how you can help. ## Filing issues Filing issues is an important way you can contribute to the Wire Project. We want your feedback on things like bugs, desired API changes, or just anything that isn't working for you. ### Bugs If your issue is a bug, open one [here](https://github.com/google/wire/issues/new). The easiest way to file an issue with all the right information is to run `go bug`. `go bug` will print out a handy template of questions and system information that will help us get to the root of the issue quicker. ### Changes Unlike the core Go project, we do not have a formal proposal process for changes. If you have a change you would like to see in Wire, please file an issue with the necessary details. ### Triaging The Go Cloud team triages issues at least every two weeks, but usually within two business days. Bugs or feature requests are either placed into a **Sprint** milestone which means the issue is intended to be worked on. Issues that we would like to address but do not have time for are placed into the [Unplanned][] milestone. [Unplanned]: https://github.com/google/wire/milestone/1 ## Contributing Code We love accepting contributions! If your change is minor, please feel free to submit a [pull request](https://help.github.com/articles/about-pull-requests/). If your change is larger, or adds a feature, please file an issue beforehand so that we can discuss the change. You're welcome to file an implementation pull request immediately as well, although we generally lean towards discussing the change and then reviewing the implementation separately.
### Finding something to work on If you want to write some code, but don't know where to start or what you might want to do, take a look at our [Unplanned][] milestone. This is where you can find issues we would like to address but can't currently find time for. See if any of the latest ones look interesting! If you need help before you can start work, you can comment on the issue and we will try to help as best we can. ### Contributor License Agreement Contributions to this project can only be made by those who have signed Google's Contributor License Agreement. You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to [cla.developers.google.com](https://cla.developers.google.com/) to see your current agreements on file or to sign a new one. As a personal contributor, you only need to sign the Google CLA once across all Google projects. If you've already signed the CLA, there is no need to do it again. If you are submitting code on behalf of your employer, there's [a separate corporate CLA that your employer manages for you](https://opensource.google.com/docs/cla/#external-contributors). ## Making a pull request * Follow the normal [pull request flow](https://help.github.com/articles/creating-a-pull-request/) * Build your changes using Go 1.11 with Go modules enabled. Wire's continuous integration uses Go modules in order to ensure [reproducible builds](https://research.swtch.com/vgo-repro). * Test your changes using `go test ./...`. Please add tests that show the change does what it says it does, even if there wasn't a test in the first place. * Feel free to make as many commits as you want; we will squash them all into a single commit before merging your change. * Check the diffs, write a useful description (including something like `Fixes #123` if it's fixing a bug) and send the PR out. * [Travis CI](http://travis-ci.com) will run tests against the PR. This should happen within 10 minutes or so.
If a test fails, go back to the coding stage and try to fix the test and push the same branch again. You won't need to make a new pull request, the changes will be rolled directly into the PR you already opened. Wait for Travis again. There is no need to assign a reviewer to the PR, the project team will assign someone for review during the standard [triage](#triaging) process. ## Code review All submissions, including submissions by project members, require review. It is almost never the case that a pull request is accepted without some changes requested, so please do not be offended! When you have finished making requested changes to your pull request, please make a comment containing "PTAL" (Please Take Another Look) on your pull request. GitHub notifications can be noisy, and it is unfortunately easy for things to be lost in the shuffle. Once your PR is approved (hooray!) the reviewer will squash your commits into a single commit, and then merge the commit onto the Wire master branch. Thank you! ## Github code review workflow conventions (For project members and frequent contributors.) As a contributor: - Try hard to make each Pull Request as small and focused as possible. In particular, this means that if a reviewer asks you to do something that is beyond the scope of the Pull Request, the best practice is to file another issue and reference it from the Pull Request rather than just adding more commits to the existing PR. - Adding someone as a Reviewer means "please feel free to look and comment"; the review is optional. Choose as many Reviewers as you'd like. - Adding someone as an Assignee means that the Pull Request should not be submitted until they approve. If you choose multiple Assignees, wait until all of them approve. It is fine to ask someone if they are OK with being removed as an Assignee. - Note that if you don't select any assignees, ContributeBot will turn all of your Reviewers into Assignees. 
- Make as many commits as you want locally, but try not to push them to Github until you've addressed comments; this allows the email notification about the push to be a signal to reviewers that the PR is ready to be looked at again. - When there may be confusion about what should happen next for a PR, be explicit; add a "PTAL" comment if it is ready for review again, or a "Please hold off on reviewing for now" if you are still working on addressing comments. - "Resolve" comments that you are sure you've addressed; let your reviewers resolve ones that you're not sure about. - Do not use `git push --force`; this can cause comments from your reviewers that are associated with a specific commit to be lost. This implies that once you've sent a Pull Request, you should use `git merge` instead of `git rebase` to incorporate commits from the master branch. As a reviewer: - Be timely in your review process, especially if you are an Assignee. - Try to use `Start a Review` instead of single comments, to reduce email spam. - "Resolve" your own comments if they have been addressed. - If you want your review to be blocking, and are not currently an Assignee, add yourself as an Assignee. When squashing-and-merging: - Ensure that **all** of the Assignees have approved. - Do a final review of the one-line PR summary, ensuring that it accurately describes the change. - Delete the automatically added commit lines; these are generally not interesting and make commit history harder to read. ================================================ FILE: vendor/github.com/google/wire/CONTRIBUTORS ================================================ # This is the official list of people who can contribute # (and typically have contributed) code to the Wire repository. # The AUTHORS file lists the copyright holders; this file # lists people. For example, Google employees are listed here # but not in AUTHORS, because Google holds the copyright. 
# # Names should be added to this file only after verifying that # the individual or the individual's organization has agreed to # the appropriate Contributor License Agreement, found here: # # http://code.google.com/legal/individual-cla-v1.0.html # http://code.google.com/legal/corporate-cla-v1.0.html # # The agreement for individuals can be filled out on the web. # # When adding J Random Contributor's name to this file, # either J's name or J's organization's name should be # added to the AUTHORS file, depending on whether the # individual or corporate CLA was used. # Names should be added to this file like so: # Individual's name # Individual's name # # An entry with multiple email addresses specifies that the # first address should be used in the submit logs and # that the other addresses should be recognized as the # same person when interacting with Git. # Please keep the list sorted. Chris Lewis Christina Austin <4240737+clausti@users.noreply.github.com> Eno Compton Issac Trotts ktr Oleg Kovalov Robert van Gent Ross Light Tuo Shan Yoichiro Shimizu Zachary Romero ================================================ FILE: vendor/github.com/google/wire/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: vendor/github.com/google/wire/README.md ================================================ # Wire: Automated Initialization in Go [![Build Status](https://travis-ci.com/google/wire.svg?branch=master)][travis] [![godoc](https://godoc.org/github.com/google/wire?status.svg)][godoc] [![Coverage Status](https://coveralls.io/repos/github/google/wire/badge.svg?branch=master)](https://coveralls.io/github/google/wire?branch=master) Wire is a code generation tool that automates connecting components using [dependency injection][]. Dependencies between components are represented in Wire as function parameters, encouraging explicit initialization instead of global variables. Because Wire operates without runtime state or reflection, code written to be used with Wire is useful even for hand-written initialization. For an overview, see the [introductory blog post][]. [dependency injection]: https://en.wikipedia.org/wiki/Dependency_injection [introductory blog post]: https://blog.golang.org/wire [godoc]: https://godoc.org/github.com/google/wire [travis]: https://travis-ci.com/google/wire ## Installing Install Wire by running: ```shell go get github.com/google/wire/cmd/wire ``` and ensuring that `$GOPATH/bin` is added to your `$PATH`. 
## Documentation - [Tutorial][] - [User Guide][] - [Best Practices][] - [FAQ][] [Tutorial]: ./_tutorial/README.md [Best Practices]: ./docs/best-practices.md [FAQ]: ./docs/faq.md [User Guide]: ./docs/guide.md ## Project status **This project is in alpha and is not yet suitable for production.** While in alpha, the API is subject to breaking changes. ## Community You can contact us on the [go-cloud mailing list][]. This project is covered by the Go [Code of Conduct][]. [Code of Conduct]: ./CODE_OF_CONDUCT.md [go-cloud mailing list]: https://groups.google.com/forum/#!forum/go-cloud ================================================ FILE: vendor/github.com/google/wire/go.mod ================================================ module github.com/google/wire require ( github.com/google/go-cmp v0.2.0 github.com/pmezard/go-difflib v1.0.0 golang.org/x/tools v0.0.0-20181017214349-06f26fdaaa28 ) ================================================ FILE: vendor/github.com/google/wire/go.sum ================================================ github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= golang.org/x/tools v0.0.0-20181017214349-06f26fdaaa28 h1:vnbqcYKfOxPnXXUlBo7t+R4pVIh0wInyOSNxih1S9Dc= golang.org/x/tools v0.0.0-20181017214349-06f26fdaaa28/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= ================================================ FILE: vendor/github.com/google/wire/wire.go ================================================ // Copyright 2018 The Wire Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package wire contains directives for Wire code generation. // For an overview of working with Wire, see the user guide at // https://github.com/google/wire/blob/master/docs/guide.md // // The directives in this package are used as input to the Wire code generation // tool. The entry point of Wire's analysis are injector functions: function // templates denoted by only containing a call to Build. The arguments to Build // describes a set of providers and the Wire code generation tool builds a // directed acylic graph of the providers' output types. The generated code will // fill in the function template by using the providers from the provider set to // instantiate any needed types. package wire // ProviderSet is a marker type that collects a group of providers. type ProviderSet struct{} // NewSet creates a new provider set that includes the providers in its // arguments. Each argument is a function value, a struct (zero) value, a // provider set, a call to Bind, a call to Value, or a call to InterfaceValue. // // Passing a function value to NewSet declares that the function's first // return value type will be provided by calling the function. The arguments // to the function will come from the providers for their types. As such, all // the parameters must be of non-identical types. The function may optionally // return an error as its last return value and a cleanup function as the // second return value. A cleanup function must be of type func() and is // guaranteed to be called before the cleanup function of any of the // provider's inputs. 
If any provider returns an error, the injector function // will call all the appropriate cleanup functions and return the error from // the injector function. // // Passing a struct value of type S to NewSet declares that both S and *S will // be provided by creating a new value of the appropriate type by filling in // each field of S using the provider of the field's type. // // Passing a ProviderSet to NewSet is the same as if the set's contents // were passed as arguments to NewSet directly. // // The behavior of passing the result of a call to other functions in this // package are described in their respective doc comments. func NewSet(...interface{}) ProviderSet { return ProviderSet{} } // Build is placed in the body of an injector function template to declare the // providers to use. The Wire code generation tool will fill in an // implementation of the function. The arguments to Build are interpreted the // same as NewSet: they determine the provider set presented to Wire's // dependency graph. Build returns an error message that can be sent to a call // to panic(). // // The parameters of the injector function are used as inputs in the dependency // graph. // // Similar to provider functions passed into NewSet, the first return value is // the output of the injector function, the optional second return value is a // cleanup function, and the optional last return value is an error. If any of // the provider functions in the injector function's provider set return errors // or cleanup functions, the corresponding return value must be present in the // injector function template. 
// // Examples: // // func injector(ctx context.Context) (*sql.DB, error) { // wire.Build(otherpkg.FooSet, myProviderFunc) // return nil, nil // } // // func injector(ctx context.Context) (*sql.DB, error) { // panic(wire.Build(otherpkg.FooSet, myProviderFunc)) // } func Build(...interface{}) string { return "implementation not generated, run wire" } // A Binding maps an interface to a concrete type. type Binding struct{} // Bind declares that a concrete type should be used to satisfy a // dependency on the type of iface, which must be a pointer to an // interface type. // // Example: // // type Fooer interface { // Foo() // } // // type MyFoo struct{} // // func (MyFoo) Foo() {} // // var MySet = wire.NewSet( // MyFoo{}, // wire.Bind(new(Fooer), new(MyFoo))) func Bind(iface, to interface{}) Binding { return Binding{} } // A ProvidedValue is an expression that is copied to the generated injector. type ProvidedValue struct{} // Value binds an expression to provide the type of the expression. // The expression may not be an interface value; use InterfaceValue for that. // // Example: // // var MySet = wire.NewSet(wire.Value([]string(nil))) func Value(interface{}) ProvidedValue { return ProvidedValue{} } // InterfaceValue binds an expression to provide a specific interface type. // // Example: // // var MySet = wire.NewSet(wire.InterfaceValue(new(io.Reader), os.Stdin)) func InterfaceValue(typ interface{}, x interface{}) ProvidedValue { return ProvidedValue{} } ================================================ FILE: vendor/github.com/gorilla/context/LICENSE ================================================ Copyright (c) 2012 Rodrigo Moraes. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: vendor/github.com/gorilla/context/README.md ================================================ context ======= [![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context) gorilla/context is a general purpose registry for global request variables. > Note: gorilla/context, having been born well before `context.Context` existed, does not play well > with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`. 
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context ================================================ FILE: vendor/github.com/gorilla/context/context.go ================================================ // Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package context import ( "net/http" "sync" "time" ) var ( mutex sync.RWMutex data = make(map[*http.Request]map[interface{}]interface{}) datat = make(map[*http.Request]int64) ) // Set stores a value for a given key in a given request. func Set(r *http.Request, key, val interface{}) { mutex.Lock() if data[r] == nil { data[r] = make(map[interface{}]interface{}) datat[r] = time.Now().Unix() } data[r][key] = val mutex.Unlock() } // Get returns a value stored for a given key in a given request. func Get(r *http.Request, key interface{}) interface{} { mutex.RLock() if ctx := data[r]; ctx != nil { value := ctx[key] mutex.RUnlock() return value } mutex.RUnlock() return nil } // GetOk returns stored value and presence state like multi-value return of map access. func GetOk(r *http.Request, key interface{}) (interface{}, bool) { mutex.RLock() if _, ok := data[r]; ok { value, ok := data[r][key] mutex.RUnlock() return value, ok } mutex.RUnlock() return nil, false } // GetAll returns all stored values for the request as a map. Nil is returned for invalid requests. func GetAll(r *http.Request) map[interface{}]interface{} { mutex.RLock() if context, ok := data[r]; ok { result := make(map[interface{}]interface{}, len(context)) for k, v := range context { result[k] = v } mutex.RUnlock() return result } mutex.RUnlock() return nil } // GetAllOk returns all stored values for the request as a map and a boolean value that indicates if // the request was registered. 
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) { mutex.RLock() context, ok := data[r] result := make(map[interface{}]interface{}, len(context)) for k, v := range context { result[k] = v } mutex.RUnlock() return result, ok } // Delete removes a value stored for a given key in a given request. func Delete(r *http.Request, key interface{}) { mutex.Lock() if data[r] != nil { delete(data[r], key) } mutex.Unlock() } // Clear removes all values stored for a given request. // // This is usually called by a handler wrapper to clean up request // variables at the end of a request lifetime. See ClearHandler(). func Clear(r *http.Request) { mutex.Lock() clear(r) mutex.Unlock() } // clear is Clear without the lock. func clear(r *http.Request) { delete(data, r) delete(datat, r) } // Purge removes request data stored for longer than maxAge, in seconds. // It returns the amount of requests removed. // // If maxAge <= 0, all request data is removed. // // This is only used for sanity check: in case context cleaning was not // properly set some request data can be kept forever, consuming an increasing // amount of memory. In case this is detected, Purge() must be called // periodically until the problem is fixed. func Purge(maxAge int) int { mutex.Lock() count := 0 if maxAge <= 0 { count = len(data) data = make(map[*http.Request]map[interface{}]interface{}) datat = make(map[*http.Request]int64) } else { min := time.Now().Unix() - int64(maxAge) for r := range data { if datat[r] < min { clear(r) count++ } } } mutex.Unlock() return count } // ClearHandler wraps an http.Handler and clears request values at the end // of a request lifetime. 
func ClearHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer Clear(r) h.ServeHTTP(w, r) }) } ================================================ FILE: vendor/github.com/gorilla/context/doc.go ================================================ // Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. /* Package context stores values shared during a request lifetime. Note: gorilla/context, having been born well before `context.Context` existed, does not play well > with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`. For example, a router can set variables extracted from the URL and later application handlers can access those values, or it can be used to store sessions values to be saved at the end of a request. There are several others common uses. The idea was posted by Brad Fitzpatrick to the go-nuts mailing list: http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53 Here's the basic usage: first define the keys that you will need. The key type is interface{} so a key can be of any type that supports equality. Here we define a key using a custom int type to avoid name collisions: package foo import ( "github.com/gorilla/context" ) type key int const MyKey key = 0 Then set a variable. Variables are bound to an http.Request object, so you need a request instance to set a value: context.Set(r, MyKey, "bar") The application can later access the variable using the same key you provided: func MyHandler(w http.ResponseWriter, r *http.Request) { // val is "bar". val := context.Get(r, foo.MyKey) // returns ("bar", true) val, ok := context.GetOk(r, foo.MyKey) // ... 
} And that's all about the basic usage. We discuss some other ideas below. Any type can be stored in the context. To enforce a given type, make the key private and wrap Get() and Set() to accept and return values of a specific type: type key int const mykey key = 0 // GetMyKey returns a value for this package from the request values. func GetMyKey(r *http.Request) SomeType { if rv := context.Get(r, mykey); rv != nil { return rv.(SomeType) } return nil } // SetMyKey sets a value for this package in the request values. func SetMyKey(r *http.Request, val SomeType) { context.Set(r, mykey, val) } Variables must be cleared at the end of a request, to remove all values that were stored. This can be done in an http.Handler, after a request was served. Just call Clear() passing the request: context.Clear(r) ...or use ClearHandler(), which conveniently wraps an http.Handler to clear variables at the end of a request lifetime. The Routers from the packages gorilla/mux and gorilla/pat call Clear() so if you are using either of them you don't need to clear the context manually. */ package context ================================================ FILE: vendor/github.com/gorilla/mux/ISSUE_TEMPLATE.md ================================================ **What version of Go are you running?** (Paste the output of `go version`) **What version of gorilla/mux are you at?** (Paste the output of `git rev-parse HEAD` inside `$GOPATH/src/github.com/gorilla/mux`) **Describe your problem** (and what you have tried so far) **Paste a minimal, runnable, reproduction of your issue below** (use backticks to format it) ================================================ FILE: vendor/github.com/gorilla/mux/LICENSE ================================================ Copyright (c) 2012 Rodrigo Moraes. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: vendor/github.com/gorilla/mux/README.md ================================================ # gorilla/mux [![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux) [![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux) [![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge) ![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png) http://www.gorillatoolkit.org/pkg/mux Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to their respective handler. The name mux stands for "HTTP request multiplexer". Like the standard `http.ServeMux`, `mux.Router` matches incoming requests against a list of registered routes and calls a handler for the route that matches the URL or other conditions. The main features are: * It implements the `http.Handler` interface so it is compatible with the standard `http.ServeMux`. * Requests can be matched based on URL host, path, path prefix, schemes, header and query values, HTTP methods or using custom matchers. * URL hosts, paths and query values can have variables with an optional regular expression. * Registered URLs can be built, or "reversed", which helps maintaining references to resources. * Routes can be used as subrouters: nested routes are only tested if the parent route matches. This is useful to define groups of routes that share common conditions like a host, a path prefix or other repeated attributes. As a bonus, this optimizes request matching. 
--- * [Install](#install) * [Examples](#examples) * [Matching Routes](#matching-routes) * [Static Files](#static-files) * [Registered URLs](#registered-urls) * [Walking Routes](#walking-routes) * [Graceful Shutdown](#graceful-shutdown) * [Middleware](#middleware) * [Testing Handlers](#testing-handlers) * [Full Example](#full-example) --- ## Install With a [correctly configured](https://golang.org/doc/install#testing) Go toolchain: ```sh go get -u github.com/gorilla/mux ``` ## Examples Let's start registering a couple of URL paths and handlers: ```go func main() { r := mux.NewRouter() r.HandleFunc("/", HomeHandler) r.HandleFunc("/products", ProductsHandler) r.HandleFunc("/articles", ArticlesHandler) http.Handle("/", r) } ``` Here we register three routes mapping URL paths to handlers. This is equivalent to how `http.HandleFunc()` works: if an incoming request URL matches one of the paths, the corresponding handler is called passing (`http.ResponseWriter`, `*http.Request`) as parameters. Paths can have variables. They are defined using the format `{name}` or `{name:pattern}`. If a regular expression pattern is not defined, the matched variable will be anything until the next slash. For example: ```go r := mux.NewRouter() r.HandleFunc("/products/{key}", ProductHandler) r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler) r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) ``` The names are used to create a map of route variables which can be retrieved calling `mux.Vars()`: ```go func ArticlesCategoryHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "Category: %v\n", vars["category"]) } ``` And this is all you need to know about the basic usage. More advanced options are explained below. ### Matching Routes Routes can also be restricted to a domain or subdomain. Just define a host pattern to be matched. 
They can also have variables: ```go r := mux.NewRouter() // Only matches if domain is "www.example.com". r.Host("www.example.com") // Matches a dynamic subdomain. r.Host("{subdomain:[a-z]+}.domain.com") ``` There are several other matchers that can be added. To match path prefixes: ```go r.PathPrefix("/products/") ``` ...or HTTP methods: ```go r.Methods("GET", "POST") ``` ...or URL schemes: ```go r.Schemes("https") ``` ...or header values: ```go r.Headers("X-Requested-With", "XMLHttpRequest") ``` ...or query values: ```go r.Queries("key", "value") ``` ...or to use a custom matcher function: ```go r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool { return r.ProtoMajor == 0 }) ``` ...and finally, it is possible to combine several matchers in a single route: ```go r.HandleFunc("/products", ProductsHandler). Host("www.example.com"). Methods("GET"). Schemes("http") ``` Routes are tested in the order they were added to the router. If two routes match, the first one wins: ```go r := mux.NewRouter() r.HandleFunc("/specific", specificHandler) r.PathPrefix("/").Handler(catchAllHandler) ``` Setting the same matching conditions again and again can be boring, so we have a way to group several routes that share the same requirements. We call it "subrouting". For example, let's say we have several URLs that should only match when the host is `www.example.com`. Create a route for that host and get a "subrouter" from it: ```go r := mux.NewRouter() s := r.Host("www.example.com").Subrouter() ``` Then register routes in the subrouter: ```go s.HandleFunc("/products/", ProductsHandler) s.HandleFunc("/products/{key}", ProductHandler) s.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) ``` The three URL paths we registered above will only be tested if the domain is `www.example.com`, because the subrouter is tested first. This is not only convenient, but also optimizes request matching. You can create subrouters combining any attribute matchers accepted by a route. 
Subrouters can be used to create domain or path "namespaces": you define subrouters in a central place and then parts of the app can register its paths relatively to a given subrouter. There's one more thing about subroutes. When a subrouter has a path prefix, the inner routes use it as base for their paths: ```go r := mux.NewRouter() s := r.PathPrefix("/products").Subrouter() // "/products/" s.HandleFunc("/", ProductsHandler) // "/products/{key}/" s.HandleFunc("/{key}/", ProductHandler) // "/products/{key}/details" s.HandleFunc("/{key}/details", ProductDetailsHandler) ``` ### Static Files Note that the path provided to `PathPrefix()` represents a "wildcard": calling `PathPrefix("/static/").Handler(...)` means that the handler will be passed any request that matches "/static/\*". This makes it easy to serve static files with mux: ```go func main() { var dir string flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir") flag.Parse() r := mux.NewRouter() // This will serve files under http://localhost:8000/static/ r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir)))) srv := &http.Server{ Handler: r, Addr: "127.0.0.1:8000", // Good practice: enforce timeouts for servers you create! WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, } log.Fatal(srv.ListenAndServe()) } ``` ### Registered URLs Now let's see how to build registered URLs. Routes can be named. All routes that define a name can have their URLs built, or "reversed". We define a name calling `Name()` on a route. For example: ```go r := mux.NewRouter() r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). Name("article") ``` To build a URL, get the route and call the `URL()` method, passing a sequence of key/value pairs for the route variables. 
For the previous route, we would do: ```go url, err := r.Get("article").URL("category", "technology", "id", "42") ``` ...and the result will be a `url.URL` with the following path: ``` "/articles/technology/42" ``` This also works for host and query value variables: ```go r := mux.NewRouter() r.Host("{subdomain}.domain.com"). Path("/articles/{category}/{id:[0-9]+}"). Queries("filter", "{filter}"). HandlerFunc(ArticleHandler). Name("article") // url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla" url, err := r.Get("article").URL("subdomain", "news", "category", "technology", "id", "42", "filter", "gorilla") ``` All variables defined in the route are required, and their values must conform to the corresponding patterns. These requirements guarantee that a generated URL will always match a registered route -- the only exception is for explicitly defined "build-only" routes which never match. Regex support also exists for matching Headers within a route. For example, we could do: ```go r.HeadersRegexp("Content-Type", "application/(text|json)") ``` ...and the route will match both requests with a Content-Type of `application/json` as well as `application/text` There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do: ```go // "http://news.domain.com/" host, err := r.Get("article").URLHost("subdomain", "news") // "/articles/technology/42" path, err := r.Get("article").URLPath("category", "technology", "id", "42") ``` And if you use subrouters, host and path defined separately can be built as well: ```go r := mux.NewRouter() s := r.Host("{subdomain}.domain.com").Subrouter() s.Path("/articles/{category}/{id:[0-9]+}"). HandlerFunc(ArticleHandler). 
Name("article") // "http://news.domain.com/articles/technology/42" url, err := r.Get("article").URL("subdomain", "news", "category", "technology", "id", "42") ``` ### Walking Routes The `Walk` function on `mux.Router` can be used to visit all of the routes that are registered on a router. For example, the following prints all of the registered routes: ```go package main import ( "fmt" "net/http" "strings" "github.com/gorilla/mux" ) func handler(w http.ResponseWriter, r *http.Request) { return } func main() { r := mux.NewRouter() r.HandleFunc("/", handler) r.HandleFunc("/products", handler).Methods("POST") r.HandleFunc("/articles", handler).Methods("GET") r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT") r.HandleFunc("/authors", handler).Queries("surname", "{surname}") err := r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { pathTemplate, err := route.GetPathTemplate() if err == nil { fmt.Println("ROUTE:", pathTemplate) } pathRegexp, err := route.GetPathRegexp() if err == nil { fmt.Println("Path regexp:", pathRegexp) } queriesTemplates, err := route.GetQueriesTemplates() if err == nil { fmt.Println("Queries templates:", strings.Join(queriesTemplates, ",")) } queriesRegexps, err := route.GetQueriesRegexp() if err == nil { fmt.Println("Queries regexps:", strings.Join(queriesRegexps, ",")) } methods, err := route.GetMethods() if err == nil { fmt.Println("Methods:", strings.Join(methods, ",")) } fmt.Println() return nil }) if err != nil { fmt.Println(err) } http.Handle("/", r) } ``` ### Graceful Shutdown Go 1.8 introduced the ability to [gracefully shutdown](https://golang.org/doc/go1.8#http_shutdown) a `*http.Server`. 
Here's how to do that alongside `mux`: ```go package main import ( "context" "flag" "log" "net/http" "os" "os/signal" "time" "github.com/gorilla/mux" ) func main() { var wait time.Duration flag.DurationVar(&wait, "graceful-timeout", time.Second * 15, "the duration for which the server gracefully wait for existing connections to finish - e.g. 15s or 1m") flag.Parse() r := mux.NewRouter() // Add your routes as needed srv := &http.Server{ Addr: "0.0.0.0:8080", // Good practice to set timeouts to avoid Slowloris attacks. WriteTimeout: time.Second * 15, ReadTimeout: time.Second * 15, IdleTimeout: time.Second * 60, Handler: r, // Pass our instance of gorilla/mux in. } // Run our server in a goroutine so that it doesn't block. go func() { if err := srv.ListenAndServe(); err != nil { log.Println(err) } }() c := make(chan os.Signal, 1) // We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C) // SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught. signal.Notify(c, os.Interrupt) // Block until we receive our signal. <-c // Create a deadline to wait for. ctx, cancel := context.WithTimeout(context.Background(), wait) defer cancel() // Doesn't block if no connections, but will otherwise wait // until the timeout deadline. srv.Shutdown(ctx) // Optionally, you could run srv.Shutdown in a goroutine and block on // <-ctx.Done() if your application should wait for other services // to finalize based on context cancellation. log.Println("shutting down") os.Exit(0) } ``` ### Middleware Mux supports the addition of middlewares to a [Router](https://godoc.org/github.com/gorilla/mux#Router), which are executed in the order they are added if a match is found, including its subrouters. Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or `ResponseWriter` hijacking. 
Mux middlewares are defined using the de facto standard type: ```go type MiddlewareFunc func(http.Handler) http.Handler ``` Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed to it, and then calls the handler passed as parameter to the MiddlewareFunc. This takes advantage of closures being able to access variables from the context where they are created, while retaining the signature enforced by the receivers. A very basic middleware which logs the URI of the request being handled could be written as: ```go func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Do stuff here log.Println(r.RequestURI) // Call the next handler, which can be another middleware in the chain, or the final handler. next.ServeHTTP(w, r) }) } ``` Middlewares can be added to a router using `Router.Use()`: ```go r := mux.NewRouter() r.HandleFunc("/", handler) r.Use(loggingMiddleware) ``` A more complex authentication middleware, which maps session tokens to users, could be written as: ```go // Define our struct type authenticationMiddleware struct { tokenUsers map[string]string } // Initialize it somewhere func (amw *authenticationMiddleware) Populate() { amw.tokenUsers["00000000"] = "user0" amw.tokenUsers["aaaaaaaa"] = "userA" amw.tokenUsers["05f717e5"] = "randomUser" amw.tokenUsers["deadbeef"] = "user0" } // Middleware function, which will be called for each request func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := r.Header.Get("X-Session-Token") if user, found := amw.tokenUsers[token]; found { // We found the token in our map log.Printf("Authenticated user %s\n", user) // Pass down the request to the next middleware (or final handler) next.ServeHTTP(w, r) } else { // Write an error and stop the handler chain http.Error(w, "Forbidden",
http.StatusForbidden) } }) } ``` ```go r := mux.NewRouter() r.HandleFunc("/", handler) amw := authenticationMiddleware{tokenUsers: make(map[string]string)} amw.Populate() r.Use(amw.Middleware) ``` Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it. ### Testing Handlers Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_. First, our simple HTTP handler: ```go // endpoints.go package main func HealthCheckHandler(w http.ResponseWriter, r *http.Request) { // A very simple health check. // Headers must be set before WriteHeader is called; later changes are not sent. w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) // In the future we could report back on the status of our DB, or our cache // (e.g. Redis) by performing a simple PING, and include them in the response. io.WriteString(w, `{"alive": true}`) } func main() { r := mux.NewRouter() r.HandleFunc("/health", HealthCheckHandler) log.Fatal(http.ListenAndServe("localhost:8080", r)) } ``` Our test code: ```go // endpoints_test.go package main import ( "net/http" "net/http/httptest" "testing" ) func TestHealthCheckHandler(t *testing.T) { // Create a request to pass to our handler. We don't have any query parameters for now, so we'll // pass 'nil' as the third parameter. req, err := http.NewRequest("GET", "/health", nil) if err != nil { t.Fatal(err) } // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder() handler := http.HandlerFunc(HealthCheckHandler) // Our handlers satisfy http.Handler, so we can call their ServeHTTP method // directly and pass in our Request and ResponseRecorder. handler.ServeHTTP(rr, req) // Check the status code is what we expect. if status := rr.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) } // Check the response body is what we expect. expected := `{"alive": true}` if rr.Body.String() != expected { t.Errorf("handler returned unexpected body: got %v want %v", rr.Body.String(), expected) } } ``` In the case that our routes have [variables](#examples), we can pass those in the request. We could write [table-driven tests](https://dave.cheney.net/2013/06/09/writing-table-driven-tests-in-go) to test multiple possible route variables as needed. ```go // endpoints.go func main() { r := mux.NewRouter() // A route with a route variable: r.HandleFunc("/metrics/{type}", MetricsHandler) log.Fatal(http.ListenAndServe("localhost:8080", r)) } ``` Our test file, with a table-driven test of `routeVariables`: ```go // endpoints_test.go func TestMetricsHandler(t *testing.T) { tt := []struct{ routeVariable string shouldPass bool }{ {"goroutines", true}, {"heap", true}, {"counters", true}, {"queries", true}, {"adhadaeqm3k", false}, } for _, tc := range tt { path := fmt.Sprintf("/metrics/%s", tc.routeVariable) req, err := http.NewRequest("GET", path, nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() handler := http.HandlerFunc(MetricsHandler) handler.ServeHTTP(rr, req) // In this case, our MetricsHandler returns a non-200 response // for a route variable it doesn't know about. 
if rr.Code == http.StatusOK && !tc.shouldPass { t.Errorf("handler should have failed on routeVariable %s: got %v want %v", tc.routeVariable, rr.Code, http.StatusOK) } } } ``` ## Full Example Here's a complete, runnable example of a small `mux` based server: ```go package main import ( "net/http" "log" "github.com/gorilla/mux" ) func YourHandler(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Gorilla!\n")) } func main() { r := mux.NewRouter() // Routes consist of a path and a handler function. r.HandleFunc("/", YourHandler) // Bind to a port and pass our router in log.Fatal(http.ListenAndServe(":8000", r)) } ``` ## License BSD licensed. See the LICENSE file for details.

================================================
FILE: vendor/github.com/gorilla/mux/context_gorilla.go
================================================
// +build !go1.7

package mux

import (
	"net/http"

	"github.com/gorilla/context"
)

// contextGet returns the value stored under key for r in gorilla/context's
// per-request registry. This variant is only built for Go versions before 1.7.
func contextGet(r *http.Request, key interface{}) interface{} {
	return context.Get(r, key)
}

// contextSet stores val under key for r in gorilla/context's registry and
// returns r itself (the request is not copied). A nil val is not stored.
func contextSet(r *http.Request, key, val interface{}) *http.Request {
	if val == nil {
		return r
	}

	context.Set(r, key, val)
	return r
}

// contextClear removes every value gorilla/context holds for r.
func contextClear(r *http.Request) {
	context.Clear(r)
}

================================================
FILE: vendor/github.com/gorilla/mux/context_native.go
================================================
// +build go1.7

package mux

import (
	"context"
	"net/http"
)

// contextGet returns the value stored under key in r's own context.Context.
func contextGet(r *http.Request, key interface{}) interface{} {
	return r.Context().Value(key)
}

// contextSet returns a shallow copy of r whose context carries val under key.
// Callers must use the returned request; r itself is not modified. A nil val
// is not stored and r is returned unchanged.
func contextSet(r *http.Request, key, val interface{}) *http.Request {
	if val == nil {
		return r
	}

	return r.WithContext(context.WithValue(r.Context(), key, val))
}

// contextClear is a no-op in this variant: values live on the request's own
// context.Context rather than in a package-level registry, so there is nothing
// to remove. It exists to match the signature of the !go1.7 variant.
func contextClear(r *http.Request) {
	return
}

================================================
FILE: vendor/github.com/gorilla/mux/doc.go
================================================
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. /* Package mux implements a request router and dispatcher. The name mux stands for "HTTP request multiplexer". Like the standard http.ServeMux, mux.Router matches incoming requests against a list of registered routes and calls a handler for the route that matches the URL or other conditions. The main features are: * Requests can be matched based on URL host, path, path prefix, schemes, header and query values, HTTP methods or using custom matchers. * URL hosts, paths and query values can have variables with an optional regular expression. * Registered URLs can be built, or "reversed", which helps maintain references to resources. * Routes can be used as subrouters: nested routes are only tested if the parent route matches. This is useful to define groups of routes that share common conditions like a host, a path prefix or other repeated attributes. As a bonus, this optimizes request matching. * It implements the http.Handler interface so it is compatible with the standard http.ServeMux. Let's start by registering a couple of URL paths and handlers: func main() { r := mux.NewRouter() r.HandleFunc("/", HomeHandler) r.HandleFunc("/products", ProductsHandler) r.HandleFunc("/articles", ArticlesHandler) http.Handle("/", r) } Here we register three routes mapping URL paths to handlers. This is equivalent to how http.HandleFunc() works: if an incoming request URL matches one of the paths, the corresponding handler is called passing (http.ResponseWriter, *http.Request) as parameters. Paths can have variables. They are defined using the format {name} or {name:pattern}. If a regular expression pattern is not defined, the matched variable will be anything until the next slash.
For example: r := mux.NewRouter() r.HandleFunc("/products/{key}", ProductHandler) r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler) r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) Groups can be used inside patterns, as long as they are non-capturing (?:re). For example: r.HandleFunc("/articles/{category}/{sort:(?:asc|desc|new)}", ArticlesCategoryHandler) The names are used to create a map of route variables which can be retrieved calling mux.Vars(): vars := mux.Vars(request) category := vars["category"] Note that if any capturing groups are present, mux will panic() during parsing. To prevent this, convert any capturing groups to non-capturing, e.g. change "/{sort:(asc|desc)}" to "/{sort:(?:asc|desc)}". This is a change from prior versions which behaved unpredictably when capturing groups were present. And this is all you need to know about the basic usage. More advanced options are explained below. Routes can also be restricted to a domain or subdomain. Just define a host pattern to be matched. They can also have variables: r := mux.NewRouter() // Only matches if domain is "www.example.com". r.Host("www.example.com") // Matches a dynamic subdomain. r.Host("{subdomain:[a-z]+}.domain.com") There are several other matchers that can be added. To match path prefixes: r.PathPrefix("/products/") ...or HTTP methods: r.Methods("GET", "POST") ...or URL schemes: r.Schemes("https") ...or header values: r.Headers("X-Requested-With", "XMLHttpRequest") ...or query values: r.Queries("key", "value") ...or to use a custom matcher function: r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool { return r.ProtoMajor == 0 }) ...and finally, it is possible to combine several matchers in a single route: r.HandleFunc("/products", ProductsHandler). Host("www.example.com"). Methods("GET"). Schemes("http") Setting the same matching conditions again and again can be boring, so we have a way to group several routes that share the same requirements. 
We call it "subrouting". For example, let's say we have several URLs that should only match when the host is "www.example.com". Create a route for that host and get a "subrouter" from it: r := mux.NewRouter() s := r.Host("www.example.com").Subrouter() Then register routes in the subrouter: s.HandleFunc("/products/", ProductsHandler) s.HandleFunc("/products/{key}", ProductHandler) s.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler) The three URL paths we registered above will only be tested if the domain is "www.example.com", because the subrouter is tested first. This is not only convenient, but also optimizes request matching. You can create subrouters combining any attribute matchers accepted by a route. Subrouters can be used to create domain or path "namespaces": you define subrouters in a central place and then parts of the app can register their paths relative to a given subrouter. There's one more thing about subrouters. When a subrouter has a path prefix, the inner routes use it as the base for their paths: r := mux.NewRouter() s := r.PathPrefix("/products").Subrouter() // "/products/" s.HandleFunc("/", ProductsHandler) // "/products/{key}/" s.HandleFunc("/{key}/", ProductHandler) // "/products/{key}/details" s.HandleFunc("/{key}/details", ProductDetailsHandler) Note that the path provided to PathPrefix() represents a "wildcard": calling PathPrefix("/static/").Handler(...) means that the handler will be passed any request that matches "/static/*". This makes it easy to serve static files with mux: func main() { var dir string flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir") flag.Parse() r := mux.NewRouter() // This will serve files under http://localhost:8000/static/ r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir)))) srv := &http.Server{ Handler: r, Addr: "127.0.0.1:8000", // Good practice: enforce timeouts for servers you create!
WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, } log.Fatal(srv.ListenAndServe()) } Now let's see how to build registered URLs. Routes can be named. All routes that define a name can have their URLs built, or "reversed". We define a name calling Name() on a route. For example: r := mux.NewRouter() r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). Name("article") To build a URL, get the route and call the URL() method, passing a sequence of key/value pairs for the route variables. For the previous route, we would do: url, err := r.Get("article").URL("category", "technology", "id", "42") ...and the result will be a url.URL with the following path: "/articles/technology/42" This also works for host and query value variables: r := mux.NewRouter() r.Host("{subdomain}.domain.com"). Path("/articles/{category}/{id:[0-9]+}"). Queries("filter", "{filter}"). HandlerFunc(ArticleHandler). Name("article") // url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla" url, err := r.Get("article").URL("subdomain", "news", "category", "technology", "id", "42", "filter", "gorilla") All variables defined in the route are required, and their values must conform to the corresponding patterns. These requirements guarantee that a generated URL will always match a registered route -- the only exception is for explicitly defined "build-only" routes which never match. Regex support also exists for matching Headers within a route. For example, we could do: r.HeadersRegexp("Content-Type", "application/(text|json)") ...and the route will match both requests with a Content-Type of `application/json` as well as `application/text` There's also a way to build only the URL host or path for a route: use the methods URLHost() or URLPath() instead. 
For the previous route, we would do: // "http://news.domain.com/" host, err := r.Get("article").URLHost("subdomain", "news") // "/articles/technology/42" path, err := r.Get("article").URLPath("category", "technology", "id", "42") And if you use subrouters, host and path defined separately can be built as well: r := mux.NewRouter() s := r.Host("{subdomain}.domain.com").Subrouter() s.Path("/articles/{category}/{id:[0-9]+}"). HandlerFunc(ArticleHandler). Name("article") // "http://news.domain.com/articles/technology/42" url, err := r.Get("article").URL("subdomain", "news", "category", "technology", "id", "42") Mux supports the addition of middlewares to a Router, which are executed in the order they are added if a match is found, including its subrouters. Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or ResponseWriter hijacking. type MiddlewareFunc func(http.Handler) http.Handler Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed to it, and then calls the handler passed as parameter to the MiddlewareFunc (closures can access variables from the context where they are created). A very basic middleware which logs the URI of the request being handled could be written as: func simpleMw(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Do stuff here log.Println(r.RequestURI) // Call the next handler, which can be another middleware in the chain, or the final handler. 
next.ServeHTTP(w, r) }) } Middlewares can be added to a router using `Router.Use()`: r := mux.NewRouter() r.HandleFunc("/", handler) r.Use(simpleMw) A more complex authentication middleware, which maps session token to users, could be written as: // Define our struct type authenticationMiddleware struct { tokenUsers map[string]string } // Initialize it somewhere func (amw *authenticationMiddleware) Populate() { amw.tokenUsers["00000000"] = "user0" amw.tokenUsers["aaaaaaaa"] = "userA" amw.tokenUsers["05f717e5"] = "randomUser" amw.tokenUsers["deadbeef"] = "user0" } // Middleware function, which will be called for each request func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := r.Header.Get("X-Session-Token") if user, found := amw.tokenUsers[token]; found { // We found the token in our map log.Printf("Authenticated user %s\n", user) next.ServeHTTP(w, r) } else { http.Error(w, "Forbidden", http.StatusForbidden) } }) } r := mux.NewRouter() r.HandleFunc("/", handler) amw := authenticationMiddleware{} amw.Populate() r.Use(amw.Middleware) Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. */ package mux ================================================ FILE: vendor/github.com/gorilla/mux/middleware.go ================================================ package mux import "net/http" // MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler. // Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed // to it, and then calls the handler passed as parameter to the MiddlewareFunc. type MiddlewareFunc func(http.Handler) http.Handler // middleware interface is anything which implements a MiddlewareFunc named Middleware. 
type middleware interface { Middleware(handler http.Handler) http.Handler } // MiddlewareFunc also implements the middleware interface. func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler { return mw(handler) } // Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router. func (r *Router) Use(mwf ...MiddlewareFunc) { for _, fn := range mwf { r.middlewares = append(r.middlewares, fn) } } // useInterface appends a middleware to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router. func (r *Router) useInterface(mw middleware) { r.middlewares = append(r.middlewares, mw) } ================================================ FILE: vendor/github.com/gorilla/mux/mux.go ================================================ // Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mux import ( "errors" "fmt" "net/http" "path" "regexp" ) var ( ErrMethodMismatch = errors.New("method is not allowed") ErrNotFound = errors.New("no matching route was found") ) // NewRouter returns a new router instance. func NewRouter() *Router { return &Router{namedRoutes: make(map[string]*Route), KeepContext: false} } // Router registers routes to be matched and dispatches a handler. // // It implements the http.Handler interface, so it can be registered to serve // requests: // // var router = mux.NewRouter() // // func main() { // http.Handle("/", router) // } // // Or, for Google App Engine, register it in a init() function: // // func init() { // http.Handle("/", router) // } // // This will send all incoming requests to the router. type Router struct { // Configurable Handler to be used when no route matches. 
NotFoundHandler http.Handler // Configurable Handler to be used when the request method does not match the route. MethodNotAllowedHandler http.Handler // Parent route, if this is a subrouter. parent parentRoute // Routes to be matched, in order. routes []*Route // Routes by name for URL building. namedRoutes map[string]*Route // See Router.StrictSlash(). This defines the flag for new routes. strictSlash bool // See Router.SkipClean(). This defines the flag for new routes. skipClean bool // If true, do not clear the request context after handling the request. // This has no effect when go1.7+ is used, since the context is stored // on the request itself. KeepContext bool // see Router.UseEncodedPath(). This defines a flag for all routes. useEncodedPath bool // Slice of middlewares to be called after a match is found middlewares []middleware } // Match attempts to match the given request against the router's registered routes. // // If the request matches a route of this router or one of its subrouters the Route, // Handler, and Vars fields of the the match argument are filled and this function // returns true. // // If the request does not match any of this router's or its subrouters' routes // then this function returns false. If available, a reason for the match failure // will be filled in the match argument's MatchErr field. If the match failure type // (eg: not found) has a registered handler, the handler is assigned to the Handler // field of the match argument. 
func (r *Router) Match(req *http.Request, match *RouteMatch) bool { for _, route := range r.routes { if route.Match(req, match) { // Build middleware chain if no error was found if match.MatchErr == nil { for i := len(r.middlewares) - 1; i >= 0; i-- { match.Handler = r.middlewares[i].Middleware(match.Handler) } } return true } } if match.MatchErr == ErrMethodMismatch { if r.MethodNotAllowedHandler != nil { match.Handler = r.MethodNotAllowedHandler return true } else { return false } } // Closest match for a router (includes sub-routers) if r.NotFoundHandler != nil { match.Handler = r.NotFoundHandler match.MatchErr = ErrNotFound return true } match.MatchErr = ErrNotFound return false } // ServeHTTP dispatches the handler registered in the matched route. // // When there is a match, the route variables can be retrieved calling // mux.Vars(request). func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if !r.skipClean { path := req.URL.Path if r.useEncodedPath { path = req.URL.EscapedPath() } // Clean path to canonical form and redirect. if p := cleanPath(path); p != path { // Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query. // This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue: // http://code.google.com/p/go/issues/detail?id=5252 url := *req.URL url.Path = p p = url.String() w.Header().Set("Location", p) w.WriteHeader(http.StatusMovedPermanently) return } } var match RouteMatch var handler http.Handler if r.Match(req, &match) { handler = match.Handler req = setVars(req, match.Vars) req = setCurrentRoute(req, match.Route) } if handler == nil && match.MatchErr == ErrMethodMismatch { handler = methodNotAllowedHandler() } if handler == nil { handler = http.NotFoundHandler() } if !r.KeepContext { defer contextClear(req) } handler.ServeHTTP(w, req) } // Get returns a route registered with the given name. 
func (r *Router) Get(name string) *Route {
	named := r.getNamedRoutes()
	return named[name]
}

// GetRoute returns the route registered with the given name. It is the
// original name of Get and is kept for backwards compatibility; prefer Get.
func (r *Router) GetRoute(name string) *Route {
	return r.Get(name)
}

// StrictSlash sets the trailing-slash behavior for routes created after this
// call. It starts out false.
//
// When true, if the route path is "/path/", a request for "/path" is
// redirected to "/path/", and vice versa. In other words, the application
// always sees the path exactly as written in the route.
//
// When false, a route for "/path" does not match "/path/", and vice versa.
//
// The redirect is an HTTP 301 (Moved Permanently). For routes with a
// non-idempotent method (e.g. POST, PUT), most clients will repeat the
// redirected request as a GET. Use middleware or client settings to adjust
// this as needed.
//
// Special case: for a route that sets a path prefix with PathPrefix(), strict
// slash is ignored, because a prefix alone cannot determine where to redirect.
// Subrouters created from that route still inherit the original StrictSlash
// setting.
func (r *Router) StrictSlash(value bool) *Router {
	r.strictSlash = value
	return r
}

// SkipClean sets the path-cleaning behavior for routes created after this
// call. It starts out false. Users should be careful about which routes are
// not cleaned.
//
// When true, a route path such as "/path//to" keeps its double slash. This is
// useful for routes like /fetch/http://xkcd.com/534/.
//
// When false, the path is cleaned, so /fetch/http://xkcd.com/534/ becomes
// /fetch/http/xkcd.com/534.
func (r *Router) SkipClean(value bool) *Router {
	r.skipClean = value
	return r
}

// UseEncodedPath tells the router to match the encoded original path
// to the routes.
// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to". // // If not called, the router will match the unencoded path to the routes. // For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to" func (r *Router) UseEncodedPath() *Router { r.useEncodedPath = true return r } // ---------------------------------------------------------------------------- // parentRoute // ---------------------------------------------------------------------------- func (r *Router) getBuildScheme() string { if r.parent != nil { return r.parent.getBuildScheme() } return "" } // getNamedRoutes returns the map where named routes are registered. func (r *Router) getNamedRoutes() map[string]*Route { if r.namedRoutes == nil { if r.parent != nil { r.namedRoutes = r.parent.getNamedRoutes() } else { r.namedRoutes = make(map[string]*Route) } } return r.namedRoutes } // getRegexpGroup returns regexp definitions from the parent route, if any. func (r *Router) getRegexpGroup() *routeRegexpGroup { if r.parent != nil { return r.parent.getRegexpGroup() } return nil } func (r *Router) buildVars(m map[string]string) map[string]string { if r.parent != nil { m = r.parent.buildVars(m) } return m } // ---------------------------------------------------------------------------- // Route factories // ---------------------------------------------------------------------------- // NewRoute registers an empty route. func (r *Router) NewRoute() *Route { route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath} r.routes = append(r.routes, route) return route } // Handle registers a new route with a matcher for the URL path. // See Route.Path() and Route.Handler(). func (r *Router) Handle(path string, handler http.Handler) *Route { return r.NewRoute().Path(path).Handler(handler) } // HandleFunc registers a new route with a matcher for the URL path. // See Route.Path() and Route.HandlerFunc(). 
func (r *Router) HandleFunc(path string, f func(http.ResponseWriter, *http.Request)) *Route {
	route := r.NewRoute().Path(path)
	return route.HandlerFunc(f)
}

// Headers registers a new route with a matcher for request header values.
// See Route.Headers().
func (r *Router) Headers(pairs ...string) *Route {
	route := r.NewRoute()
	return route.Headers(pairs...)
}

// Host registers a new route with a matcher for the URL host.
// See Route.Host().
func (r *Router) Host(tpl string) *Route {
	route := r.NewRoute()
	return route.Host(tpl)
}

// MatcherFunc registers a new route with a custom matcher function.
// See Route.MatcherFunc().
func (r *Router) MatcherFunc(f MatcherFunc) *Route {
	route := r.NewRoute()
	return route.MatcherFunc(f)
}

// Methods registers a new route with a matcher for HTTP methods.
// See Route.Methods().
func (r *Router) Methods(methods ...string) *Route {
	route := r.NewRoute()
	return route.Methods(methods...)
}

// Path registers a new route with a matcher for the URL path.
// See Route.Path().
func (r *Router) Path(tpl string) *Route {
	route := r.NewRoute()
	return route.Path(tpl)
}

// PathPrefix registers a new route with a matcher for the URL path prefix.
// See Route.PathPrefix().
func (r *Router) PathPrefix(tpl string) *Route {
	route := r.NewRoute()
	return route.PathPrefix(tpl)
}

// Queries registers a new route with a matcher for URL query values.
// See Route.Queries().
func (r *Router) Queries(pairs ...string) *Route {
	route := r.NewRoute()
	return route.Queries(pairs...)
}

// Schemes registers a new route with a matcher for URL schemes.
// See Route.Schemes().
func (r *Router) Schemes(schemes ...string) *Route {
	route := r.NewRoute()
	return route.Schemes(schemes...)
}

// BuildVarsFunc registers a new route with a custom function for modifying
// route variables before building a URL.
func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
	route := r.NewRoute()
	return route.BuildVarsFunc(f)
}

// Walk walks the router and all its sub-routers, calling walkFn for each route
// in the tree. The routes are walked in the order they were added. Sub-routers
// are explored depth-first.
func (r *Router) Walk(walkFn WalkFunc) error { return r.walk(walkFn, []*Route{}) } // SkipRouter is used as a return value from WalkFuncs to indicate that the // router that walk is about to descend down to should be skipped. var SkipRouter = errors.New("skip this router") // WalkFunc is the type of the function called for each route visited by Walk. // At every invocation, it is given the current route, and the current router, // and a list of ancestor routes that lead to the current route. type WalkFunc func(route *Route, router *Router, ancestors []*Route) error func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error { for _, t := range r.routes { err := walkFn(t, r, ancestors) if err == SkipRouter { continue } if err != nil { return err } for _, sr := range t.matchers { if h, ok := sr.(*Router); ok { ancestors = append(ancestors, t) err := h.walk(walkFn, ancestors) if err != nil { return err } ancestors = ancestors[:len(ancestors)-1] } } if h, ok := t.handler.(*Router); ok { ancestors = append(ancestors, t) err := h.walk(walkFn, ancestors) if err != nil { return err } ancestors = ancestors[:len(ancestors)-1] } } return nil } // ---------------------------------------------------------------------------- // Context // ---------------------------------------------------------------------------- // RouteMatch stores information about a matched route. type RouteMatch struct { Route *Route Handler http.Handler Vars map[string]string // MatchErr is set to appropriate matching error // It is set to ErrMethodMismatch if there is a mismatch in // the request method and route method MatchErr error } type contextKey int const ( varsKey contextKey = iota routeKey ) // Vars returns the route variables for the current request, if any. func Vars(r *http.Request) map[string]string { if rv := contextGet(r, varsKey); rv != nil { return rv.(map[string]string) } return nil } // CurrentRoute returns the matched route for the current request, if any. 
// This only works when called inside the handler of the matched route // because the matched route is stored in the request context which is cleared // after the handler returns, unless the KeepContext option is set on the // Router. func CurrentRoute(r *http.Request) *Route { if rv := contextGet(r, routeKey); rv != nil { return rv.(*Route) } return nil } func setVars(r *http.Request, val interface{}) *http.Request { return contextSet(r, varsKey, val) } func setCurrentRoute(r *http.Request, val interface{}) *http.Request { return contextSet(r, routeKey, val) } // ---------------------------------------------------------------------------- // Helpers // ---------------------------------------------------------------------------- // cleanPath returns the canonical path for p, eliminating . and .. elements. // Borrowed from the net/http package. func cleanPath(p string) string { if p == "" { return "/" } if p[0] != '/' { p = "/" + p } np := path.Clean(p) // path.Clean removes trailing slash except for root; // put the trailing slash back if necessary. if p[len(p)-1] == '/' && np != "/" { np += "/" } return np } // uniqueVars returns an error if two slices contain duplicated strings. func uniqueVars(s1, s2 []string) error { for _, v1 := range s1 { for _, v2 := range s2 { if v1 == v2 { return fmt.Errorf("mux: duplicated route variable %q", v2) } } } return nil } // checkPairs returns the count of strings passed in, and an error if // the count is not an even number. func checkPairs(pairs ...string) (int, error) { length := len(pairs) if length%2 != 0 { return length, fmt.Errorf( "mux: number of parameters must be multiple of 2, got %v", pairs) } return length, nil } // mapFromPairsToString converts variadic string parameters to a // string to string map. func mapFromPairsToString(pairs ...string) (map[string]string, error) { length, err := checkPairs(pairs...) 
if err != nil { return nil, err } m := make(map[string]string, length/2) for i := 0; i < length; i += 2 { m[pairs[i]] = pairs[i+1] } return m, nil } // mapFromPairsToRegex converts variadic string parameters to a // string to regex map. func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) { length, err := checkPairs(pairs...) if err != nil { return nil, err } m := make(map[string]*regexp.Regexp, length/2) for i := 0; i < length; i += 2 { regex, err := regexp.Compile(pairs[i+1]) if err != nil { return nil, err } m[pairs[i]] = regex } return m, nil } // matchInArray returns true if the given string value is in the array. func matchInArray(arr []string, value string) bool { for _, v := range arr { if v == value { return true } } return false } // matchMapWithString returns true if the given key/value pairs exist in a given map. func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool { for k, v := range toCheck { // Check if key exists. if canonicalKey { k = http.CanonicalHeaderKey(k) } if values := toMatch[k]; values == nil { return false } else if v != "" { // If value was defined as an empty string we only check that the // key exists. Otherwise we also check for equality. valueExists := false for _, value := range values { if v == value { valueExists = true break } } if !valueExists { return false } } } return true } // matchMapWithRegex returns true if the given key/value pairs exist in a given map compiled against // the given regex func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool { for k, v := range toCheck { // Check if key exists. if canonicalKey { k = http.CanonicalHeaderKey(k) } if values := toMatch[k]; values == nil { return false } else if v != nil { // If value was defined as an empty string we only check that the // key exists. Otherwise we also check for equality. 
valueExists := false for _, value := range values { if v.MatchString(value) { valueExists = true break } } if !valueExists { return false } } } return true } // methodNotAllowed replies to the request with an HTTP status code 405. func methodNotAllowed(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusMethodNotAllowed) } // methodNotAllowedHandler returns a simple request handler // that replies to each request with a status code 405. func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) } ================================================ FILE: vendor/github.com/gorilla/mux/regexp.go ================================================ // Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mux import ( "bytes" "fmt" "net/http" "net/url" "regexp" "strconv" "strings" ) type routeRegexpOptions struct { strictSlash bool useEncodedPath bool } type regexpType int const ( regexpTypePath regexpType = 0 regexpTypeHost regexpType = 1 regexpTypePrefix regexpType = 2 regexpTypeQuery regexpType = 3 ) // newRouteRegexp parses a route template and returns a routeRegexp, // used to match a host, a path or a query string. // // It will extract named variables, assemble a regexp to be matched, create // a "reverse" template to build URLs and compile regexps to validate variable // values used in URL building. // // Previously we accepted only Python-like identifiers for variable // names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that // name and pattern can't be empty, and names can't contain a colon. func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*routeRegexp, error) { // Check if it is well-formed. idxs, errBraces := braceIndices(tpl) if errBraces != nil { return nil, errBraces } // Backup the original. template := tpl // Now let's parse it. 
defaultPattern := "[^/]+" if typ == regexpTypeQuery { defaultPattern = ".*" } else if typ == regexpTypeHost { defaultPattern = "[^.]+" } // Only match strict slash if not matching if typ != regexpTypePath { options.strictSlash = false } // Set a flag for strictSlash. endSlash := false if options.strictSlash && strings.HasSuffix(tpl, "/") { tpl = tpl[:len(tpl)-1] endSlash = true } varsN := make([]string, len(idxs)/2) varsR := make([]*regexp.Regexp, len(idxs)/2) pattern := bytes.NewBufferString("") pattern.WriteByte('^') reverse := bytes.NewBufferString("") var end int var err error for i := 0; i < len(idxs); i += 2 { // Set all values we are interested in. raw := tpl[end:idxs[i]] end = idxs[i+1] parts := strings.SplitN(tpl[idxs[i]+1:end-1], ":", 2) name := parts[0] patt := defaultPattern if len(parts) == 2 { patt = parts[1] } // Name or pattern can't be empty. if name == "" || patt == "" { return nil, fmt.Errorf("mux: missing name or pattern in %q", tpl[idxs[i]:end]) } // Build the regexp pattern. fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt) // Build the reverse template. fmt.Fprintf(reverse, "%s%%s", raw) // Append variable name and compiled pattern. varsN[i/2] = name varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt)) if err != nil { return nil, err } } // Add the remaining. raw := tpl[end:] pattern.WriteString(regexp.QuoteMeta(raw)) if options.strictSlash { pattern.WriteString("[/]?") } if typ == regexpTypeQuery { // Add the default pattern if the query value is empty if queryVal := strings.SplitN(template, "=", 2)[1]; queryVal == "" { pattern.WriteString(defaultPattern) } } if typ != regexpTypePrefix { pattern.WriteByte('$') } reverse.WriteString(raw) if endSlash { reverse.WriteByte('/') } // Compile full regexp. 
reg, errCompile := regexp.Compile(pattern.String()) if errCompile != nil { return nil, errCompile } // Check for capturing groups which used to work in older versions if reg.NumSubexp() != len(idxs)/2 { panic(fmt.Sprintf("route %s contains capture groups in its regexp. ", template) + "Only non-capturing groups are accepted: e.g. (?:pattern) instead of (pattern)") } // Done! return &routeRegexp{ template: template, regexpType: typ, options: options, regexp: reg, reverse: reverse.String(), varsN: varsN, varsR: varsR, }, nil } // routeRegexp stores a regexp to match a host or path and information to // collect and validate route variables. type routeRegexp struct { // The unmodified template. template string // The type of match regexpType regexpType // Options for matching options routeRegexpOptions // Expanded regexp. regexp *regexp.Regexp // Reverse template. reverse string // Variable names. varsN []string // Variable regexps (validators). varsR []*regexp.Regexp } // Match matches the regexp against the URL host or path. func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool { if r.regexpType != regexpTypeHost { if r.regexpType == regexpTypeQuery { return r.matchQueryString(req) } path := req.URL.Path if r.options.useEncodedPath { path = req.URL.EscapedPath() } return r.regexp.MatchString(path) } return r.regexp.MatchString(getHost(req)) } // url builds a URL part using the given values. func (r *routeRegexp) url(values map[string]string) (string, error) { urlValues := make([]interface{}, len(r.varsN)) for k, v := range r.varsN { value, ok := values[v] if !ok { return "", fmt.Errorf("mux: missing route variable %q", v) } if r.regexpType == regexpTypeQuery { value = url.QueryEscape(value) } urlValues[k] = value } rv := fmt.Sprintf(r.reverse, urlValues...) if !r.regexp.MatchString(rv) { // The URL is checked against the full regexp, instead of checking // individual variables. 
This is faster but to provide a good error // message, we check individual regexps if the URL doesn't match. for k, v := range r.varsN { if !r.varsR[k].MatchString(values[v]) { return "", fmt.Errorf( "mux: variable %q doesn't match, expected %q", values[v], r.varsR[k].String()) } } } return rv, nil } // getURLQuery returns a single query parameter from a request URL. // For a URL with foo=bar&baz=ding, we return only the relevant key // value pair for the routeRegexp. func (r *routeRegexp) getURLQuery(req *http.Request) string { if r.regexpType != regexpTypeQuery { return "" } templateKey := strings.SplitN(r.template, "=", 2)[0] for key, vals := range req.URL.Query() { if key == templateKey && len(vals) > 0 { return key + "=" + vals[0] } } return "" } func (r *routeRegexp) matchQueryString(req *http.Request) bool { return r.regexp.MatchString(r.getURLQuery(req)) } // braceIndices returns the first level curly brace indices from a string. // It returns an error in case of unbalanced braces. func braceIndices(s string) ([]int, error) { var level, idx int var idxs []int for i := 0; i < len(s); i++ { switch s[i] { case '{': if level++; level == 1 { idx = i } case '}': if level--; level == 0 { idxs = append(idxs, idx, i+1) } else if level < 0 { return nil, fmt.Errorf("mux: unbalanced braces in %q", s) } } } if level != 0 { return nil, fmt.Errorf("mux: unbalanced braces in %q", s) } return idxs, nil } // varGroupName builds a capturing group name for the indexed variable. func varGroupName(idx int) string { return "v" + strconv.Itoa(idx) } // ---------------------------------------------------------------------------- // routeRegexpGroup // ---------------------------------------------------------------------------- // routeRegexpGroup groups the route matchers that carry variables. type routeRegexpGroup struct { host *routeRegexp path *routeRegexp queries []*routeRegexp } // setMatch extracts the variables from the URL once a route matches. 
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) { // Store host variables. if v.host != nil { host := getHost(req) matches := v.host.regexp.FindStringSubmatchIndex(host) if len(matches) > 0 { extractVars(host, matches, v.host.varsN, m.Vars) } } path := req.URL.Path if r.useEncodedPath { path = req.URL.EscapedPath() } // Store path variables. if v.path != nil { matches := v.path.regexp.FindStringSubmatchIndex(path) if len(matches) > 0 { extractVars(path, matches, v.path.varsN, m.Vars) // Check if we should redirect. if v.path.options.strictSlash { p1 := strings.HasSuffix(path, "/") p2 := strings.HasSuffix(v.path.template, "/") if p1 != p2 { u, _ := url.Parse(req.URL.String()) if p1 { u.Path = u.Path[:len(u.Path)-1] } else { u.Path += "/" } m.Handler = http.RedirectHandler(u.String(), 301) } } } } // Store query string variables. for _, q := range v.queries { queryURL := q.getURLQuery(req) matches := q.regexp.FindStringSubmatchIndex(queryURL) if len(matches) > 0 { extractVars(queryURL, matches, q.varsN, m.Vars) } } } // getHost tries its best to return the request host. func getHost(r *http.Request) string { if r.URL.IsAbs() { return r.URL.Host } host := r.Host // Slice off any port information. if i := strings.Index(host, ":"); i != -1 { host = host[:i] } return host } func extractVars(input string, matches []int, names []string, output map[string]string) { for i, name := range names { output[name] = input[matches[2*i+2]:matches[2*i+3]] } } ================================================ FILE: vendor/github.com/gorilla/mux/route.go ================================================ // Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mux import ( "errors" "fmt" "net/http" "net/url" "regexp" "strings" ) // Route stores information to match a request and build URLs. 
type Route struct { // Parent where the route was registered (a Router). parent parentRoute // Request handler for the route. handler http.Handler // List of matchers. matchers []matcher // Manager for the variables from host and path. regexp *routeRegexpGroup // If true, when the path pattern is "/path/", accessing "/path" will // redirect to the former and vice versa. strictSlash bool // If true, when the path pattern is "/path//to", accessing "/path//to" // will not redirect skipClean bool // If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to" useEncodedPath bool // The scheme used when building URLs. buildScheme string // If true, this route never matches: it is only used to build URLs. buildOnly bool // The name used to build URLs. name string // Error resulted from building a route. err error buildVarsFunc BuildVarsFunc } func (r *Route) SkipClean() bool { return r.skipClean } // Match matches the route against the request. func (r *Route) Match(req *http.Request, match *RouteMatch) bool { if r.buildOnly || r.err != nil { return false } var matchErr error // Match everything. for _, m := range r.matchers { if matched := m.Match(req, match); !matched { if _, ok := m.(methodMatcher); ok { matchErr = ErrMethodMismatch continue } matchErr = nil return false } } if matchErr != nil { match.MatchErr = matchErr return false } if match.MatchErr == ErrMethodMismatch { // We found a route which matches request method, clear MatchErr match.MatchErr = nil // Then override the mis-matched handler match.Handler = r.handler } // Yay, we have a match. Let's collect some info about it. if match.Route == nil { match.Route = r } if match.Handler == nil { match.Handler = r.handler } if match.Vars == nil { match.Vars = make(map[string]string) } // Set variables. 
if r.regexp != nil { r.regexp.setMatch(req, match, r) } return true } // ---------------------------------------------------------------------------- // Route attributes // ---------------------------------------------------------------------------- // GetError returns an error resulted from building the route, if any. func (r *Route) GetError() error { return r.err } // BuildOnly sets the route to never match: it is only used to build URLs. func (r *Route) BuildOnly() *Route { r.buildOnly = true return r } // Handler -------------------------------------------------------------------- // Handler sets a handler for the route. func (r *Route) Handler(handler http.Handler) *Route { if r.err == nil { r.handler = handler } return r } // HandlerFunc sets a handler function for the route. func (r *Route) HandlerFunc(f func(http.ResponseWriter, *http.Request)) *Route { return r.Handler(http.HandlerFunc(f)) } // GetHandler returns the handler for the route, if any. func (r *Route) GetHandler() http.Handler { return r.handler } // Name ----------------------------------------------------------------------- // Name sets the name for the route, used to build URLs. // If the name was registered already it will be overwritten. func (r *Route) Name(name string) *Route { if r.name != "" { r.err = fmt.Errorf("mux: route already has name %q, can't set %q", r.name, name) } if r.err == nil { r.name = name r.getNamedRoutes()[name] = r } return r } // GetName returns the name for the route, if any. func (r *Route) GetName() string { return r.name } // ---------------------------------------------------------------------------- // Matchers // ---------------------------------------------------------------------------- // matcher types try to match a request. type matcher interface { Match(*http.Request, *RouteMatch) bool } // addMatcher adds a matcher to the route. 
func (r *Route) addMatcher(m matcher) *Route { if r.err == nil { r.matchers = append(r.matchers, m) } return r } // addRegexpMatcher adds a host or path matcher and builder to a route. func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error { if r.err != nil { return r.err } r.regexp = r.getRegexpGroup() if typ == regexpTypePath || typ == regexpTypePrefix { if len(tpl) > 0 && tpl[0] != '/' { return fmt.Errorf("mux: path must start with a slash, got %q", tpl) } if r.regexp.path != nil { tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl } } rr, err := newRouteRegexp(tpl, typ, routeRegexpOptions{ strictSlash: r.strictSlash, useEncodedPath: r.useEncodedPath, }) if err != nil { return err } for _, q := range r.regexp.queries { if err = uniqueVars(rr.varsN, q.varsN); err != nil { return err } } if typ == regexpTypeHost { if r.regexp.path != nil { if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil { return err } } r.regexp.host = rr } else { if r.regexp.host != nil { if err = uniqueVars(rr.varsN, r.regexp.host.varsN); err != nil { return err } } if typ == regexpTypeQuery { r.regexp.queries = append(r.regexp.queries, rr) } else { r.regexp.path = rr } } r.addMatcher(rr) return nil } // Headers -------------------------------------------------------------------- // headerMatcher matches the request against header values. type headerMatcher map[string]string func (m headerMatcher) Match(r *http.Request, match *RouteMatch) bool { return matchMapWithString(m, r.Header, true) } // Headers adds a matcher for request header values. // It accepts a sequence of key/value pairs to be matched. For example: // // r := mux.NewRouter() // r.Headers("Content-Type", "application/json", // "X-Requested-With", "XMLHttpRequest") // // The above route will only match if both request header values match. // If the value is an empty string, it will match any value if the key is set. 
func (r *Route) Headers(pairs ...string) *Route { if r.err == nil { var headers map[string]string headers, r.err = mapFromPairsToString(pairs...) return r.addMatcher(headerMatcher(headers)) } return r } // headerRegexMatcher matches the request against the route given a regex for the header type headerRegexMatcher map[string]*regexp.Regexp func (m headerRegexMatcher) Match(r *http.Request, match *RouteMatch) bool { return matchMapWithRegex(m, r.Header, true) } // HeadersRegexp accepts a sequence of key/value pairs, where the value has regex // support. For example: // // r := mux.NewRouter() // r.HeadersRegexp("Content-Type", "application/(text|json)", // "X-Requested-With", "XMLHttpRequest") // // The above route will only match if both the request header matches both regular expressions. // If the value is an empty string, it will match any value if the key is set. // Use the start and end of string anchors (^ and $) to match an exact value. func (r *Route) HeadersRegexp(pairs ...string) *Route { if r.err == nil { var headers map[string]*regexp.Regexp headers, r.err = mapFromPairsToRegex(pairs...) return r.addMatcher(headerRegexMatcher(headers)) } return r } // Host ----------------------------------------------------------------------- // Host adds a matcher for the URL host. // It accepts a template with zero or more URL variables enclosed by {}. // Variables can define an optional regexp pattern to be matched: // // - {name} matches anything until the next dot. // // - {name:pattern} matches the given regexp pattern. // // For example: // // r := mux.NewRouter() // r.Host("www.example.com") // r.Host("{subdomain}.domain.com") // r.Host("{subdomain:[a-z]+}.domain.com") // // Variable names must be unique in a given route. They can be retrieved // calling mux.Vars(request). 
func (r *Route) Host(tpl string) *Route { r.err = r.addRegexpMatcher(tpl, regexpTypeHost) return r } // MatcherFunc ---------------------------------------------------------------- // MatcherFunc is the function signature used by custom matchers. type MatcherFunc func(*http.Request, *RouteMatch) bool // Match returns the match for a given request. func (m MatcherFunc) Match(r *http.Request, match *RouteMatch) bool { return m(r, match) } // MatcherFunc adds a custom function to be used as request matcher. func (r *Route) MatcherFunc(f MatcherFunc) *Route { return r.addMatcher(f) } // Methods -------------------------------------------------------------------- // methodMatcher matches the request against HTTP methods. type methodMatcher []string func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool { return matchInArray(m, r.Method) } // Methods adds a matcher for HTTP methods. // It accepts a sequence of one or more methods to be matched, e.g.: // "GET", "POST", "PUT". func (r *Route) Methods(methods ...string) *Route { for k, v := range methods { methods[k] = strings.ToUpper(v) } return r.addMatcher(methodMatcher(methods)) } // Path ----------------------------------------------------------------------- // Path adds a matcher for the URL path. // It accepts a template with zero or more URL variables enclosed by {}. The // template must start with a "/". // Variables can define an optional regexp pattern to be matched: // // - {name} matches anything until the next slash. // // - {name:pattern} matches the given regexp pattern. // // For example: // // r := mux.NewRouter() // r.Path("/products/").Handler(ProductsHandler) // r.Path("/products/{key}").Handler(ProductsHandler) // r.Path("/articles/{category}/{id:[0-9]+}"). // Handler(ArticleHandler) // // Variable names must be unique in a given route. They can be retrieved // calling mux.Vars(request). 
func (r *Route) Path(tpl string) *Route { r.err = r.addRegexpMatcher(tpl, regexpTypePath) return r } // PathPrefix ----------------------------------------------------------------- // PathPrefix adds a matcher for the URL path prefix. This matches if the given // template is a prefix of the full URL path. See Route.Path() for details on // the tpl argument. // // Note that it does not treat slashes specially ("/foobar/" will be matched by // the prefix "/foo") so you may want to use a trailing slash here. // // Also note that the setting of Router.StrictSlash() has no effect on routes // with a PathPrefix matcher. func (r *Route) PathPrefix(tpl string) *Route { r.err = r.addRegexpMatcher(tpl, regexpTypePrefix) return r } // Query ---------------------------------------------------------------------- // Queries adds a matcher for URL query values. // It accepts a sequence of key/value pairs. Values may define variables. // For example: // // r := mux.NewRouter() // r.Queries("foo", "bar", "id", "{id:[0-9]+}") // // The above route will only match if the URL contains the defined queries // values, e.g.: ?foo=bar&id=42. // // It the value is an empty string, it will match any value if the key is set. // // Variables can define an optional regexp pattern to be matched: // // - {name} matches anything until the next slash. // // - {name:pattern} matches the given regexp pattern. func (r *Route) Queries(pairs ...string) *Route { length := len(pairs) if length%2 != 0 { r.err = fmt.Errorf( "mux: number of parameters must be multiple of 2, got %v", pairs) return nil } for i := 0; i < length; i += 2 { if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], regexpTypeQuery); r.err != nil { return r } } return r } // Schemes -------------------------------------------------------------------- // schemeMatcher matches the request against URL schemes. 
type schemeMatcher []string func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool { return matchInArray(m, r.URL.Scheme) } // Schemes adds a matcher for URL schemes. // It accepts a sequence of schemes to be matched, e.g.: "http", "https". func (r *Route) Schemes(schemes ...string) *Route { for k, v := range schemes { schemes[k] = strings.ToLower(v) } if r.buildScheme == "" && len(schemes) > 0 { r.buildScheme = schemes[0] } return r.addMatcher(schemeMatcher(schemes)) } // BuildVarsFunc -------------------------------------------------------------- // BuildVarsFunc is the function signature used by custom build variable // functions (which can modify route variables before a route's URL is built). type BuildVarsFunc func(map[string]string) map[string]string // BuildVarsFunc adds a custom function to be used to modify build variables // before a route's URL is built. func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route { r.buildVarsFunc = f return r } // Subrouter ------------------------------------------------------------------ // Subrouter creates a subrouter for the route. // // It will test the inner routes only if the parent route matched. For example: // // r := mux.NewRouter() // s := r.Host("www.example.com").Subrouter() // s.HandleFunc("/products/", ProductsHandler) // s.HandleFunc("/products/{key}", ProductHandler) // s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler) // // Here, the routes registered in the subrouter won't be tested if the host // doesn't match. func (r *Route) Subrouter() *Router { router := &Router{parent: r, strictSlash: r.strictSlash} r.addMatcher(router) return router } // ---------------------------------------------------------------------------- // URL building // ---------------------------------------------------------------------------- // URL builds a URL for the route. // // It accepts a sequence of key/value pairs for the route variables. 
For // example, given this route: // // r := mux.NewRouter() // r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). // Name("article") // // ...a URL for it can be built using: // // url, err := r.Get("article").URL("category", "technology", "id", "42") // // ...which will return an url.URL with the following path: // // "/articles/technology/42" // // This also works for host variables: // // r := mux.NewRouter() // r.Host("{subdomain}.domain.com"). // HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler). // Name("article") // // // url.String() will be "http://news.domain.com/articles/technology/42" // url, err := r.Get("article").URL("subdomain", "news", // "category", "technology", // "id", "42") // // All variables defined in the route are required, and their values must // conform to the corresponding patterns. func (r *Route) URL(pairs ...string) (*url.URL, error) { if r.err != nil { return nil, r.err } if r.regexp == nil { return nil, errors.New("mux: route doesn't have a host or path") } values, err := r.prepareVars(pairs...) if err != nil { return nil, err } var scheme, host, path string queries := make([]string, 0, len(r.regexp.queries)) if r.regexp.host != nil { if host, err = r.regexp.host.url(values); err != nil { return nil, err } scheme = "http" if s := r.getBuildScheme(); s != "" { scheme = s } } if r.regexp.path != nil { if path, err = r.regexp.path.url(values); err != nil { return nil, err } } for _, q := range r.regexp.queries { var query string if query, err = q.url(values); err != nil { return nil, err } queries = append(queries, query) } return &url.URL{ Scheme: scheme, Host: host, Path: path, RawQuery: strings.Join(queries, "&"), }, nil } // URLHost builds the host part of the URL for a route. See Route.URL(). // // The route must have a host defined. 
func (r *Route) URLHost(pairs ...string) (*url.URL, error) { if r.err != nil { return nil, r.err } if r.regexp == nil || r.regexp.host == nil { return nil, errors.New("mux: route doesn't have a host") } values, err := r.prepareVars(pairs...) if err != nil { return nil, err } host, err := r.regexp.host.url(values) if err != nil { return nil, err } u := &url.URL{ Scheme: "http", Host: host, } if s := r.getBuildScheme(); s != "" { u.Scheme = s } return u, nil } // URLPath builds the path part of the URL for a route. See Route.URL(). // // The route must have a path defined. func (r *Route) URLPath(pairs ...string) (*url.URL, error) { if r.err != nil { return nil, r.err } if r.regexp == nil || r.regexp.path == nil { return nil, errors.New("mux: route doesn't have a path") } values, err := r.prepareVars(pairs...) if err != nil { return nil, err } path, err := r.regexp.path.url(values) if err != nil { return nil, err } return &url.URL{ Path: path, }, nil } // GetPathTemplate returns the template used to build the // route match. // This is useful for building simple REST API documentation and for instrumentation // against third-party services. // An error will be returned if the route does not define a path. func (r *Route) GetPathTemplate() (string, error) { if r.err != nil { return "", r.err } if r.regexp == nil || r.regexp.path == nil { return "", errors.New("mux: route doesn't have a path") } return r.regexp.path.template, nil } // GetPathRegexp returns the expanded regular expression used to match route path. // This is useful for building simple REST API documentation and for instrumentation // against third-party services. // An error will be returned if the route does not define a path. 
func (r *Route) GetPathRegexp() (string, error) { if r.err != nil { return "", r.err } if r.regexp == nil || r.regexp.path == nil { return "", errors.New("mux: route does not have a path") } return r.regexp.path.regexp.String(), nil } // GetQueriesRegexp returns the expanded regular expressions used to match the // route queries. // This is useful for building simple REST API documentation and for instrumentation // against third-party services. // An error will be returned if the route does not have queries. func (r *Route) GetQueriesRegexp() ([]string, error) { if r.err != nil { return nil, r.err } if r.regexp == nil || r.regexp.queries == nil { return nil, errors.New("mux: route doesn't have queries") } var queries []string for _, query := range r.regexp.queries { queries = append(queries, query.regexp.String()) } return queries, nil } // GetQueriesTemplates returns the templates used to build the // query matching. // This is useful for building simple REST API documentation and for instrumentation // against third-party services. // An error will be returned if the route does not define queries. func (r *Route) GetQueriesTemplates() ([]string, error) { if r.err != nil { return nil, r.err } if r.regexp == nil || r.regexp.queries == nil { return nil, errors.New("mux: route doesn't have queries") } var queries []string for _, query := range r.regexp.queries { queries = append(queries, query.template) } return queries, nil } // GetMethods returns the methods the route matches against // This is useful for building simple REST API documentation and for instrumentation // against third-party services. // An error will be returned if route does not have methods. 
func (r *Route) GetMethods() ([]string, error) { if r.err != nil { return nil, r.err } for _, m := range r.matchers { if methods, ok := m.(methodMatcher); ok { return []string(methods), nil } } return nil, errors.New("mux: route doesn't have methods") } // GetHostTemplate returns the template used to build the // route match. // This is useful for building simple REST API documentation and for instrumentation // against third-party services. // An error will be returned if the route does not define a host. func (r *Route) GetHostTemplate() (string, error) { if r.err != nil { return "", r.err } if r.regexp == nil || r.regexp.host == nil { return "", errors.New("mux: route doesn't have a host") } return r.regexp.host.template, nil } // prepareVars converts the route variable pairs into a map. If the route has a // BuildVarsFunc, it is invoked. func (r *Route) prepareVars(pairs ...string) (map[string]string, error) { m, err := mapFromPairsToString(pairs...) if err != nil { return nil, err } return r.buildVars(m), nil } func (r *Route) buildVars(m map[string]string) map[string]string { if r.parent != nil { m = r.parent.buildVars(m) } if r.buildVarsFunc != nil { m = r.buildVarsFunc(m) } return m } // ---------------------------------------------------------------------------- // parentRoute // ---------------------------------------------------------------------------- // parentRoute allows routes to know about parent host and path definitions. type parentRoute interface { getBuildScheme() string getNamedRoutes() map[string]*Route getRegexpGroup() *routeRegexpGroup buildVars(map[string]string) map[string]string } func (r *Route) getBuildScheme() string { if r.buildScheme != "" { return r.buildScheme } if r.parent != nil { return r.parent.getBuildScheme() } return "" } // getNamedRoutes returns the map where named routes are registered. func (r *Route) getNamedRoutes() map[string]*Route { if r.parent == nil { // During tests router is not always set. 
r.parent = NewRouter() } return r.parent.getNamedRoutes() } // getRegexpGroup returns regexp definitions from this route. func (r *Route) getRegexpGroup() *routeRegexpGroup { if r.regexp == nil { if r.parent == nil { // During tests router is not always set. r.parent = NewRouter() } regexp := r.parent.getRegexpGroup() if regexp == nil { r.regexp = new(routeRegexpGroup) } else { // Copy. r.regexp = &routeRegexpGroup{ host: regexp.host, path: regexp.path, queries: regexp.queries, } } } return r.regexp } ================================================ FILE: vendor/github.com/gorilla/mux/test_helpers.go ================================================ // Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mux import "net/http" // SetURLVars sets the URL variables for the given request, to be accessed via // mux.Vars for testing route behaviour. Arguments are not modified, a shallow // copy is returned. // // This API should only be used for testing purposes; it provides a way to // inject variables into the request context. Alternatively, URL variables // can be set by making a route that captures the required variables, // starting a server and sending the request to that server. func SetURLVars(r *http.Request, val map[string]string) *http.Request { return setVars(r, val) } ================================================ FILE: vendor/github.com/pmezard/go-difflib/LICENSE ================================================ Copyright (c) 2013, Patrick Mezard All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. The names of its contributors may not be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: vendor/github.com/pmezard/go-difflib/difflib/difflib.go ================================================ // Package difflib is a partial port of Python difflib module. // // It provides tools to compare sequences of strings and generate textual diffs. // // The following class and functions have been ported: // // - SequenceMatcher // // - unified_diff // // - context_diff // // Getting unified diffs was the main goal of the port. Keep in mind this code // is mostly suitable to output text differences in a human friendly way, there // are no guarantees generated diffs are consumable by patch(1). 
package difflib

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"strings"
)

// min returns the smaller of a and b.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// max returns the larger of a and b.
func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

// calculateRatio returns 2*matches/length, the similarity measure used by
// Ratio, QuickRatio and RealQuickRatio. A zero total length (two empty
// sequences) is reported as a perfect match, 1.0.
func calculateRatio(matches, length int) float64 {
	if length > 0 {
		return 2.0 * float64(matches) / float64(length)
	}
	return 1.0
}

// Match describes a matching block: a[A:A+Size] == b[B:B+Size].
type Match struct {
	A    int
	B    int
	Size int
}

// OpCode describes one edit step relating a[I1:I2] to b[J1:J2]. Tag is
// 'r' (replace), 'd' (delete), 'i' (insert) or 'e' (equal).
type OpCode struct {
	Tag byte
	I1  int
	I2  int
	J1  int
	J2  int
}

// SequenceMatcher compares sequence of strings. The basic
// algorithm predates, and is a little fancier than, an algorithm
// published in the late 1980's by Ratcliff and Obershelp under the
// hyperbolic name "gestalt pattern matching". The basic idea is to find
// the longest contiguous matching subsequence that contains no "junk"
// elements (R-O doesn't address junk). The same idea is then applied
// recursively to the pieces of the sequences to the left and to the right
// of the matching subsequence. This does not yield minimal edit
// sequences, but does tend to yield matches that "look right" to people.
//
// SequenceMatcher tries to compute a "human-friendly diff" between two
// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the
// longest *contiguous* & junk-free matching subsequence. That's what
// catches people's eyes. The Windows(tm) windiff has another interesting
// notion, pairing up elements that appear uniquely in each sequence.
// That, and the method here, appear to yield more intuitive difference
// reports than does diff. This method appears to be the least vulnerable
// to synching up on blocks of "junk lines", though (like blank lines in
// ordinary text files, or maybe "<P>" lines in HTML files). That may be
// because this is the only method of the 3 that has a *concept* of
// "junk".
//
// Timing: Basic R-O is cubic time worst case and quadratic time expected
// case. SequenceMatcher is quadratic time for the worst case and has
// expected-case behavior dependent in a complicated way on how many
// elements the sequences have in common; best case time is linear.
type SequenceMatcher struct {
	a              []string
	b              []string
	b2j            map[string][]int
	IsJunk         func(string) bool
	autoJunk       bool
	bJunk          map[string]struct{}
	matchingBlocks []Match
	fullBCount     map[string]int
	bPopular       map[string]struct{}
	opCodes        []OpCode
}

// NewMatcher returns a SequenceMatcher comparing a against b, with no
// IsJunk predicate and the "popular element" auto-junk heuristic enabled.
func NewMatcher(a, b []string) *SequenceMatcher {
	m := SequenceMatcher{autoJunk: true}
	m.SetSeqs(a, b)
	return &m
}

// NewMatcherWithJunk returns a SequenceMatcher comparing a against b.
//
// isJunk, if non-nil, marks elements of b that must never anchor a match.
// autoJunk enables the heuristic that, for a b of n >= 200 elements, also
// excludes every element occurring more than n/100+1 times ("popular").
func NewMatcherWithJunk(a, b []string, autoJunk bool,
	isJunk func(string) bool) *SequenceMatcher {

	m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk}
	m.SetSeqs(a, b)
	return &m
}

// SetSeqs sets both sequences to be compared.
func (m *SequenceMatcher) SetSeqs(a, b []string) {
	m.SetSeq1(a)
	m.SetSeq2(b)
}

// SetSeq1 sets the first sequence to be compared. The second sequence to be
// compared is not changed.
//
// SequenceMatcher computes and caches detailed information about the second
// sequence, so if you want to compare one sequence S against many sequences,
// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other
// sequences.
//
// See also SetSeqs() and SetSeq2().
func (m *SequenceMatcher) SetSeq1(a []string) {
	// Slices have no cheap identity test (the address of the parameter is
	// never the address of the field), so cached results are always reset.
	m.a = a
	m.matchingBlocks = nil
	m.opCodes = nil
}

// SetSeq2 sets the second sequence to be compared. The first sequence to be
// compared is not changed. The b-side index (b2j, junk and popular sets) is
// rebuilt immediately.
func (m *SequenceMatcher) SetSeq2(b []string) {
	m.b = b
	m.matchingBlocks = nil
	m.opCodes = nil
	m.fullBCount = nil
	m.chainB()
}

// chainB indexes m.b: b2j maps each non-junk, non-popular element to the
// ascending list of positions where it occurs in b. Junk elements (per
// IsJunk) are recorded in bJunk and popular ones (per autoJunk) in bPopular.
func (m *SequenceMatcher) chainB() {
	// Populate line -> index mapping
	b2j := map[string][]int{}
	for i, s := range m.b {
		b2j[s] = append(b2j[s], i)
	}

	// Purge junk elements
	m.bJunk = map[string]struct{}{}
	if m.IsJunk != nil {
		junk := m.bJunk
		for s := range b2j {
			if m.IsJunk(s) {
				junk[s] = struct{}{}
			}
		}
		for s := range junk {
			delete(b2j, s)
		}
	}

	// Purge remaining popular elements
	popular := map[string]struct{}{}
	n := len(m.b)
	if m.autoJunk && n >= 200 {
		ntest := n/100 + 1
		for s, indices := range b2j {
			if len(indices) > ntest {
				popular[s] = struct{}{}
			}
		}
		for s := range popular {
			delete(b2j, s)
		}
	}
	m.bPopular = popular
	m.b2j = b2j
}

// isBJunk reports whether s was classified as junk by IsJunk.
func (m *SequenceMatcher) isBJunk(s string) bool {
	_, ok := m.bJunk[s]
	return ok
}

// findLongestMatch finds the longest matching block in a[alo:ahi] and
// b[blo:bhi].
//
// If IsJunk is not defined:
//
// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
//     alo <= i <= i+k <= ahi
//     blo <= j <= j+k <= bhi
// and for all (i',j',k') meeting those conditions,
//     k >= k'
//     i <= i'
//     and if i == i', j <= j'
//
// In other words, of all maximal matching blocks, return one that
// starts earliest in a, and of all those maximal matching blocks that
// start earliest in a, return the one that starts earliest in b.
//
// If IsJunk is defined, first the longest matching block is
// determined as above, but with the additional restriction that no
// junk element appears in the block. Then that block is extended as
// far as possible by matching (only) junk elements on both sides. So
// the resulting block never matches on junk except as identical junk
// happens to be adjacent to an "interesting" match.
//
// If no blocks match, return (alo, blo, 0).
func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match { // CAUTION: stripping common prefix or suffix would be incorrect. // E.g., // ab // acab // Longest matching block is "ab", but if common prefix is // stripped, it's "a" (tied with "b"). UNIX(tm) diff does so // strip, so ends up claiming that ab is changed to acab by // inserting "ca" in the middle. That's minimal but unintuitive: // "it's obvious" that someone inserted "ac" at the front. // Windiff ends up at the same place as diff, but by pairing up // the unique 'b's and then matching the first two 'a's. besti, bestj, bestsize := alo, blo, 0 // find longest junk-free match // during an iteration of the loop, j2len[j] = length of longest // junk-free match ending with a[i-1] and b[j] j2len := map[int]int{} for i := alo; i != ahi; i++ { // look at all instances of a[i] in b; note that because // b2j has no junk keys, the loop is skipped if a[i] is junk newj2len := map[int]int{} for _, j := range m.b2j[m.a[i]] { // a[i] matches b[j] if j < blo { continue } if j >= bhi { break } k := j2len[j-1] + 1 newj2len[j] = k if k > bestsize { besti, bestj, bestsize = i-k+1, j-k+1, k } } j2len = newj2len } // Extend the best by non-junk elements on each end. In particular, // "popular" non-junk elements aren't in b2j, which greatly speeds // the inner loop above, but also means "the best" match so far // doesn't contain any junk *or* popular non-junk elements. for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) && m.a[besti-1] == m.b[bestj-1] { besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 } for besti+bestsize < ahi && bestj+bestsize < bhi && !m.isBJunk(m.b[bestj+bestsize]) && m.a[besti+bestsize] == m.b[bestj+bestsize] { bestsize += 1 } // Now that we have a wholly interesting match (albeit possibly // empty!), we may as well suck up the matching junk on each // side of it too. 
Can't think of a good reason not to, and it // saves post-processing the (possibly considerable) expense of // figuring out what to do with it. In the case of an empty // interesting match, this is clearly the right thing to do, // because no other kind of match is possible in the regions. for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) && m.a[besti-1] == m.b[bestj-1] { besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 } for besti+bestsize < ahi && bestj+bestsize < bhi && m.isBJunk(m.b[bestj+bestsize]) && m.a[besti+bestsize] == m.b[bestj+bestsize] { bestsize += 1 } return Match{A: besti, B: bestj, Size: bestsize} } // Return list of triples describing matching subsequences. // // Each triple is of the form (i, j, n), and means that // a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in // i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are // adjacent triples in the list, and the second is not the last triple in the // list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe // adjacent equal blocks. // // The last triple is a dummy, (len(a), len(b), 0), and is the only // triple with n==0. func (m *SequenceMatcher) GetMatchingBlocks() []Match { if m.matchingBlocks != nil { return m.matchingBlocks } var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match { match := m.findLongestMatch(alo, ahi, blo, bhi) i, j, k := match.A, match.B, match.Size if match.Size > 0 { if alo < i && blo < j { matched = matchBlocks(alo, i, blo, j, matched) } matched = append(matched, match) if i+k < ahi && j+k < bhi { matched = matchBlocks(i+k, ahi, j+k, bhi, matched) } } return matched } matched := matchBlocks(0, len(m.a), 0, len(m.b), nil) // It's possible that we have adjacent equal blocks in the // matching_blocks list now. nonAdjacent := []Match{} i1, j1, k1 := 0, 0, 0 for _, b := range matched { // Is this block adjacent to i1, j1, k1? 
i2, j2, k2 := b.A, b.B, b.Size if i1+k1 == i2 && j1+k1 == j2 { // Yes, so collapse them -- this just increases the length of // the first block by the length of the second, and the first // block so lengthened remains the block to compare against. k1 += k2 } else { // Not adjacent. Remember the first block (k1==0 means it's // the dummy we started with), and make the second block the // new block to compare against. if k1 > 0 { nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) } i1, j1, k1 = i2, j2, k2 } } if k1 > 0 { nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) } nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0}) m.matchingBlocks = nonAdjacent return m.matchingBlocks } // Return list of 5-tuples describing how to turn a into b. // // Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple // has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the // tuple preceding it, and likewise for j1 == the previous j2. // // The tags are characters, with these meanings: // // 'r' (replace): a[i1:i2] should be replaced by b[j1:j2] // // 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case. // // 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case. // // 'e' (equal): a[i1:i2] == b[j1:j2] func (m *SequenceMatcher) GetOpCodes() []OpCode { if m.opCodes != nil { return m.opCodes } i, j := 0, 0 matching := m.GetMatchingBlocks() opCodes := make([]OpCode, 0, len(matching)) for _, m := range matching { // invariant: we've pumped out correct diffs to change // a[:i] into b[:j], and the next matching block is // a[ai:ai+size] == b[bj:bj+size]. 
So we need to pump // out a diff to change a[i:ai] into b[j:bj], pump out // the matching block, and move (i,j) beyond the match ai, bj, size := m.A, m.B, m.Size tag := byte(0) if i < ai && j < bj { tag = 'r' } else if i < ai { tag = 'd' } else if j < bj { tag = 'i' } if tag > 0 { opCodes = append(opCodes, OpCode{tag, i, ai, j, bj}) } i, j = ai+size, bj+size // the list of matching blocks is terminated by a // sentinel with size 0 if size > 0 { opCodes = append(opCodes, OpCode{'e', ai, i, bj, j}) } } m.opCodes = opCodes return m.opCodes } // Isolate change clusters by eliminating ranges with no changes. // // Return a generator of groups with up to n lines of context. // Each group is in the same format as returned by GetOpCodes(). func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode { if n < 0 { n = 3 } codes := m.GetOpCodes() if len(codes) == 0 { codes = []OpCode{OpCode{'e', 0, 1, 0, 1}} } // Fixup leading and trailing groups if they show no changes. if codes[0].Tag == 'e' { c := codes[0] i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2} } if codes[len(codes)-1].Tag == 'e' { c := codes[len(codes)-1] i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)} } nn := n + n groups := [][]OpCode{} group := []OpCode{} for _, c := range codes { i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 // End the current group and start a new one whenever // there is a large range with no changes. if c.Tag == 'e' && i2-i1 > nn { group = append(group, OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)}) groups = append(groups, group) group = []OpCode{} i1, j1 = max(i1, i2-n), max(j1, j2-n) } group = append(group, OpCode{c.Tag, i1, i2, j1, j2}) } if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') { groups = append(groups, group) } return groups } // Return a measure of the sequences' similarity (float in [0,1]). 
// // Where T is the total number of elements in both sequences, and // M is the number of matches, this is 2.0*M / T. // Note that this is 1 if the sequences are identical, and 0 if // they have nothing in common. // // .Ratio() is expensive to compute if you haven't already computed // .GetMatchingBlocks() or .GetOpCodes(), in which case you may // want to try .QuickRatio() or .RealQuickRation() first to get an // upper bound. func (m *SequenceMatcher) Ratio() float64 { matches := 0 for _, m := range m.GetMatchingBlocks() { matches += m.Size } return calculateRatio(matches, len(m.a)+len(m.b)) } // Return an upper bound on ratio() relatively quickly. // // This isn't defined beyond that it is an upper bound on .Ratio(), and // is faster to compute. func (m *SequenceMatcher) QuickRatio() float64 { // viewing a and b as multisets, set matches to the cardinality // of their intersection; this counts the number of matches // without regard to order, so is clearly an upper bound if m.fullBCount == nil { m.fullBCount = map[string]int{} for _, s := range m.b { m.fullBCount[s] = m.fullBCount[s] + 1 } } // avail[x] is the number of times x appears in 'b' less the // number of times we've seen it in 'a' so far ... kinda avail := map[string]int{} matches := 0 for _, s := range m.a { n, ok := avail[s] if !ok { n = m.fullBCount[s] } avail[s] = n - 1 if n > 0 { matches += 1 } } return calculateRatio(matches, len(m.a)+len(m.b)) } // Return an upper bound on ratio() very quickly. // // This isn't defined beyond that it is an upper bound on .Ratio(), and // is faster to compute than either .Ratio() or .QuickRatio(). 
func (m *SequenceMatcher) RealQuickRatio() float64 { la, lb := len(m.a), len(m.b) return calculateRatio(min(la, lb), la+lb) } // Convert range to the "ed" format func formatRangeUnified(start, stop int) string { // Per the diff spec at http://www.unix.org/single_unix_specification/ beginning := start + 1 // lines start numbering with one length := stop - start if length == 1 { return fmt.Sprintf("%d", beginning) } if length == 0 { beginning -= 1 // empty ranges begin at line just before the range } return fmt.Sprintf("%d,%d", beginning, length) } // Unified diff parameters type UnifiedDiff struct { A []string // First sequence lines FromFile string // First file name FromDate string // First file time B []string // Second sequence lines ToFile string // Second file name ToDate string // Second file time Eol string // Headers end of line, defaults to LF Context int // Number of context lines } // Compare two sequences of lines; generate the delta as a unified diff. // // Unified diffs are a compact way of showing line changes and a few // lines of context. The number of context lines is set by 'n' which // defaults to three. // // By default, the diff control lines (those with ---, +++, or @@) are // created with a trailing newline. This is helpful so that inputs // created from file.readlines() result in diffs that are suitable for // file.writelines() since both the inputs and outputs have trailing // newlines. // // For inputs that do not have trailing newlines, set the lineterm // argument to "" so that the output will be uniformly newline free. // // The unidiff format normally has a header for filenames and modification // times. Any or all of these may be specified using strings for // 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'. // The modification times are normally expressed in the ISO 8601 format. 
func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error { buf := bufio.NewWriter(writer) defer buf.Flush() w := func(format string, args ...interface{}) error { _, err := buf.WriteString(fmt.Sprintf(format, args...)) return err } if len(diff.Eol) == 0 { diff.Eol = "\n" } started := false m := NewMatcher(diff.A, diff.B) for _, g := range m.GetGroupedOpCodes(diff.Context) { if !started { started = true fromDate := "" if len(diff.FromDate) > 0 { fromDate = "\t" + diff.FromDate } toDate := "" if len(diff.ToDate) > 0 { toDate = "\t" + diff.ToDate } err := w("--- %s%s%s", diff.FromFile, fromDate, diff.Eol) if err != nil { return err } err = w("+++ %s%s%s", diff.ToFile, toDate, diff.Eol) if err != nil { return err } } first, last := g[0], g[len(g)-1] range1 := formatRangeUnified(first.I1, last.I2) range2 := formatRangeUnified(first.J1, last.J2) if err := w("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil { return err } for _, c := range g { i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 if c.Tag == 'e' { for _, line := range diff.A[i1:i2] { if err := w(" " + line); err != nil { return err } } continue } if c.Tag == 'r' || c.Tag == 'd' { for _, line := range diff.A[i1:i2] { if err := w("-" + line); err != nil { return err } } } if c.Tag == 'r' || c.Tag == 'i' { for _, line := range diff.B[j1:j2] { if err := w("+" + line); err != nil { return err } } } } } return nil } // Like WriteUnifiedDiff but returns the diff a string. func GetUnifiedDiffString(diff UnifiedDiff) (string, error) { w := &bytes.Buffer{} err := WriteUnifiedDiff(w, diff) return string(w.Bytes()), err } // Convert range to the "ed" format. 
func formatRangeContext(start, stop int) string { // Per the diff spec at http://www.unix.org/single_unix_specification/ beginning := start + 1 // lines start numbering with one length := stop - start if length == 0 { beginning -= 1 // empty ranges begin at line just before the range } if length <= 1 { return fmt.Sprintf("%d", beginning) } return fmt.Sprintf("%d,%d", beginning, beginning+length-1) } type ContextDiff UnifiedDiff // Compare two sequences of lines; generate the delta as a context diff. // // Context diffs are a compact way of showing line changes and a few // lines of context. The number of context lines is set by diff.Context // which defaults to three. // // By default, the diff control lines (those with *** or ---) are // created with a trailing newline. // // For inputs that do not have trailing newlines, set the diff.Eol // argument to "" so that the output will be uniformly newline free. // // The context diff format normally has a header for filenames and // modification times. Any or all of these may be specified using // strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate. // The modification times are normally expressed in the ISO 8601 format. // If not specified, the strings default to blanks. func WriteContextDiff(writer io.Writer, diff ContextDiff) error { buf := bufio.NewWriter(writer) defer buf.Flush() var diffErr error w := func(format string, args ...interface{}) { _, err := buf.WriteString(fmt.Sprintf(format, args...)) if diffErr == nil && err != nil { diffErr = err } } if len(diff.Eol) == 0 { diff.Eol = "\n" } prefix := map[byte]string{ 'i': "+ ", 'd': "- ", 'r': "! 
", 'e': " ", } started := false m := NewMatcher(diff.A, diff.B) for _, g := range m.GetGroupedOpCodes(diff.Context) { if !started { started = true fromDate := "" if len(diff.FromDate) > 0 { fromDate = "\t" + diff.FromDate } toDate := "" if len(diff.ToDate) > 0 { toDate = "\t" + diff.ToDate } w("*** %s%s%s", diff.FromFile, fromDate, diff.Eol) w("--- %s%s%s", diff.ToFile, toDate, diff.Eol) } first, last := g[0], g[len(g)-1] w("***************" + diff.Eol) range1 := formatRangeContext(first.I1, last.I2) w("*** %s ****%s", range1, diff.Eol) for _, c := range g { if c.Tag == 'r' || c.Tag == 'd' { for _, cc := range g { if cc.Tag == 'i' { continue } for _, line := range diff.A[cc.I1:cc.I2] { w(prefix[cc.Tag] + line) } } break } } range2 := formatRangeContext(first.J1, last.J2) w("--- %s ----%s", range2, diff.Eol) for _, c := range g { if c.Tag == 'r' || c.Tag == 'i' { for _, cc := range g { if cc.Tag == 'd' { continue } for _, line := range diff.B[cc.J1:cc.J2] { w(prefix[cc.Tag] + line) } } break } } } return diffErr } // Like WriteContextDiff but returns the diff a string. func GetContextDiffString(diff ContextDiff) (string, error) { w := &bytes.Buffer{} err := WriteContextDiff(w, diff) return string(w.Bytes()), err } // Split a string on "\n" while preserving them. The output can be used // as input for UnifiedDiff and ContextDiff structures. func SplitLines(s string) []string { lines := strings.SplitAfter(s, "\n") lines[len(lines)-1] += "\n" return lines } ================================================ FILE: vendor/github.com/stretchr/objx/LICENSE.md ================================================ objx - by Mat Ryer and Tyler Bunnell The MIT License (MIT) Copyright (c) 2014 Stretchr, Inc. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: vendor/github.com/stretchr/objx/README.md ================================================ # objx * Jump into the [API Documentation](http://godoc.org/github.com/stretchr/objx) ================================================ FILE: vendor/github.com/stretchr/objx/accessors.go ================================================ package objx import ( "fmt" "regexp" "strconv" "strings" ) // arrayAccesRegexString is the regex used to extract the array number // from the access path const arrayAccesRegexString = `^(.+)\[([0-9]+)\]$` // arrayAccesRegex is the compiled arrayAccesRegexString var arrayAccesRegex = regexp.MustCompile(arrayAccesRegexString) // Get gets the value using the specified selector and // returns it inside a new Obj object. // // If it cannot find the value, Get will return a nil // value inside an instance of Obj. // // Get can only operate directly on map[string]interface{} and []interface. 
// // Example // // To access the title of the third chapter of the second book, do: // // o.Get("books[1].chapters[2].title") func (m Map) Get(selector string) *Value { rawObj := access(m, selector, nil, false, false) return &Value{data: rawObj} } // Set sets the value using the specified selector and // returns the object on which Set was called. // // Set can only operate directly on map[string]interface{} and []interface // // Example // // To set the title of the third chapter of the second book, do: // // o.Set("books[1].chapters[2].title","Time to Go") func (m Map) Set(selector string, value interface{}) Map { access(m, selector, value, true, false) return m } // access accesses the object using the selector and performs the // appropriate action. func access(current, selector, value interface{}, isSet, panics bool) interface{} { switch selector.(type) { case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: if array, ok := current.([]interface{}); ok { index := intFromInterface(selector) if index >= len(array) { if panics { panic(fmt.Sprintf("objx: Index %d is out of range. Slice only contains %d items.", index, len(array))) } return nil } return array[index] } return nil case string: selStr := selector.(string) selSegs := strings.SplitN(selStr, PathSeparator, 2) thisSel := selSegs[0] index := -1 var err error // https://github.com/stretchr/objx/issues/12 if strings.Contains(thisSel, "[") { arrayMatches := arrayAccesRegex.FindStringSubmatch(thisSel) if len(arrayMatches) > 0 { // Get the key into the map thisSel = arrayMatches[1] // Get the index into the array at the key index, err = strconv.Atoi(arrayMatches[2]) if err != nil { // This should never happen. If it does, something has gone // seriously wrong. Panic. panic("objx: Array index is not an integer. 
Must use array[int].") } } } if curMap, ok := current.(Map); ok { current = map[string]interface{}(curMap) } // get the object in question switch current.(type) { case map[string]interface{}: curMSI := current.(map[string]interface{}) if len(selSegs) <= 1 && isSet { curMSI[thisSel] = value return nil } else { current = curMSI[thisSel] } default: current = nil } if current == nil && panics { panic(fmt.Sprintf("objx: '%v' invalid on object.", selector)) } // do we need to access the item of an array? if index > -1 { if array, ok := current.([]interface{}); ok { if index < len(array) { current = array[index] } else { if panics { panic(fmt.Sprintf("objx: Index %d is out of range. Slice only contains %d items.", index, len(array))) } current = nil } } } if len(selSegs) > 1 { current = access(current, selSegs[1], value, isSet, panics) } } return current } // intFromInterface converts an interface object to the largest // representation of an unsigned integer using a type switch and // assertions func intFromInterface(selector interface{}) int { var value int switch selector.(type) { case int: value = selector.(int) case int8: value = int(selector.(int8)) case int16: value = int(selector.(int16)) case int32: value = int(selector.(int32)) case int64: value = int(selector.(int64)) case uint: value = int(selector.(uint)) case uint8: value = int(selector.(uint8)) case uint16: value = int(selector.(uint16)) case uint32: value = int(selector.(uint32)) case uint64: value = int(selector.(uint64)) default: panic("objx: array access argument is not an integer type (this should never happen)") } return value } ================================================ FILE: vendor/github.com/stretchr/objx/constants.go ================================================ package objx const ( // PathSeparator is the character used to separate the elements // of the keypath. // // For example, `location.address.city` PathSeparator string = "." 
// SignatureSeparator is the character that is used to // separate the Base64 string from the security signature. SignatureSeparator = "_" ) ================================================ FILE: vendor/github.com/stretchr/objx/conversions.go ================================================ package objx import ( "bytes" "encoding/base64" "encoding/json" "errors" "fmt" "net/url" ) // JSON converts the contained object to a JSON string // representation func (m Map) JSON() (string, error) { result, err := json.Marshal(m) if err != nil { err = errors.New("objx: JSON encode failed with: " + err.Error()) } return string(result), err } // MustJSON converts the contained object to a JSON string // representation and panics if there is an error func (m Map) MustJSON() string { result, err := m.JSON() if err != nil { panic(err.Error()) } return result } // Base64 converts the contained object to a Base64 string // representation of the JSON string representation func (m Map) Base64() (string, error) { var buf bytes.Buffer jsonData, err := m.JSON() if err != nil { return "", err } encoder := base64.NewEncoder(base64.StdEncoding, &buf) encoder.Write([]byte(jsonData)) encoder.Close() return buf.String(), nil } // MustBase64 converts the contained object to a Base64 string // representation of the JSON string representation and panics // if there is an error func (m Map) MustBase64() string { result, err := m.Base64() if err != nil { panic(err.Error()) } return result } // SignedBase64 converts the contained object to a Base64 string // representation of the JSON string representation and signs it // using the provided key. 
func (m Map) SignedBase64(key string) (string, error) { base64, err := m.Base64() if err != nil { return "", err } sig := HashWithKey(base64, key) return base64 + SignatureSeparator + sig, nil } // MustSignedBase64 converts the contained object to a Base64 string // representation of the JSON string representation and signs it // using the provided key and panics if there is an error func (m Map) MustSignedBase64(key string) string { result, err := m.SignedBase64(key) if err != nil { panic(err.Error()) } return result } /* URL Query ------------------------------------------------ */ // URLValues creates a url.Values object from an Obj. This // function requires that the wrapped object be a map[string]interface{} func (m Map) URLValues() url.Values { vals := make(url.Values) for k, v := range m { //TODO: can this be done without sprintf? vals.Set(k, fmt.Sprintf("%v", v)) } return vals } // URLQuery gets an encoded URL query representing the given // Obj. This function requires that the wrapped object be a // map[string]interface{} func (m Map) URLQuery() (string, error) { return m.URLValues().Encode(), nil } ================================================ FILE: vendor/github.com/stretchr/objx/doc.go ================================================ // objx - Go package for dealing with maps, slices, JSON and other data. // // Overview // // Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes // a powerful `Get` method (among others) that allows you to easily and quickly get // access to data within the map, without having to worry too much about type assertions, // missing data, default values etc. // // Pattern // // Objx uses a preditable pattern to make access data from within `map[string]interface{}'s // easy. 
// // Call one of the `objx.` functions to create your `objx.Map` to get going: // // m, err := objx.FromJSON(json) // // NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong, // the rest will be optimistic and try to figure things out without panicking. // // Use `Get` to access the value you're interested in. You can use dot and array // notation too: // // m.Get("places[0].latlng") // // Once you have saught the `Value` you're interested in, you can use the `Is*` methods // to determine its type. // // if m.Get("code").IsStr() { /* ... */ } // // Or you can just assume the type, and use one of the strong type methods to // extract the real value: // // m.Get("code").Int() // // If there's no value there (or if it's the wrong type) then a default value // will be returned, or you can be explicit about the default value. // // Get("code").Int(-1) // // If you're dealing with a slice of data as a value, Objx provides many useful // methods for iterating, manipulating and selecting that data. You can find out more // by exploring the index below. // // Reading data // // A simple example of how to use Objx: // // // use MustFromJSON to make an objx.Map from some JSON // m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`) // // // get the details // name := m.Get("name").Str() // age := m.Get("age").Int() // // // get their nickname (or use their name if they // // don't have one) // nickname := m.Get("nickname").Str(name) // // Ranging // // Since `objx.Map` is a `map[string]interface{}` you can treat it as such. For // example, to `range` the data, do what you would expect: // // m := objx.MustFromJSON(json) // for key, value := range m { // // /* ... do your magic ... 
*/ // // } package objx ================================================ FILE: vendor/github.com/stretchr/objx/map.go ================================================ package objx import ( "encoding/base64" "encoding/json" "errors" "io/ioutil" "net/url" "strings" ) // MSIConvertable is an interface that defines methods for converting your // custom types to a map[string]interface{} representation. type MSIConvertable interface { // MSI gets a map[string]interface{} (msi) representing the // object. MSI() map[string]interface{} } // Map provides extended functionality for working with // untyped data, in particular map[string]interface (msi). type Map map[string]interface{} // Value returns the internal value instance func (m Map) Value() *Value { return &Value{data: m} } // Nil represents a nil Map. var Nil Map = New(nil) // New creates a new Map containing the map[string]interface{} in the data argument. // If the data argument is not a map[string]interface, New attempts to call the // MSI() method on the MSIConvertable interface to create one. func New(data interface{}) Map { if _, ok := data.(map[string]interface{}); !ok { if converter, ok := data.(MSIConvertable); ok { data = converter.MSI() } else { return nil } } return Map(data.(map[string]interface{})) } // MSI creates a map[string]interface{} and puts it inside a new Map. // // The arguments follow a key, value pattern. // // Panics // // Panics if any key arugment is non-string or if there are an odd number of arguments. 
// // Example // // To easily create Maps: // // m := objx.MSI("name", "Mat", "age", 29, "subobj", objx.MSI("active", true)) // // // creates an Map equivalent to // m := objx.New(map[string]interface{}{"name": "Mat", "age": 29, "subobj": map[string]interface{}{"active": true}}) func MSI(keyAndValuePairs ...interface{}) Map { newMap := make(map[string]interface{}) keyAndValuePairsLen := len(keyAndValuePairs) if keyAndValuePairsLen%2 != 0 { panic("objx: MSI must have an even number of arguments following the 'key, value' pattern.") } for i := 0; i < keyAndValuePairsLen; i = i + 2 { key := keyAndValuePairs[i] value := keyAndValuePairs[i+1] // make sure the key is a string keyString, keyStringOK := key.(string) if !keyStringOK { panic("objx: MSI must follow 'string, interface{}' pattern. " + keyString + " is not a valid key.") } newMap[keyString] = value } return New(newMap) } // ****** Conversion Constructors // MustFromJSON creates a new Map containing the data specified in the // jsonString. // // Panics if the JSON is invalid. func MustFromJSON(jsonString string) Map { o, err := FromJSON(jsonString) if err != nil { panic("objx: MustFromJSON failed with error: " + err.Error()) } return o } // FromJSON creates a new Map containing the data specified in the // jsonString. // // Returns an error if the JSON is invalid. func FromJSON(jsonString string) (Map, error) { var data interface{} err := json.Unmarshal([]byte(jsonString), &data) if err != nil { return Nil, err } return New(data), nil } // FromBase64 creates a new Obj containing the data specified // in the Base64 string. 
// // The string is an encoded JSON string returned by Base64 func FromBase64(base64String string) (Map, error) { decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(base64String)) decoded, err := ioutil.ReadAll(decoder) if err != nil { return nil, err } return FromJSON(string(decoded)) } // MustFromBase64 creates a new Obj containing the data specified // in the Base64 string and panics if there is an error. // // The string is an encoded JSON string returned by Base64 func MustFromBase64(base64String string) Map { result, err := FromBase64(base64String) if err != nil { panic("objx: MustFromBase64 failed with error: " + err.Error()) } return result } // FromSignedBase64 creates a new Obj containing the data specified // in the Base64 string. // // The string is an encoded JSON string returned by SignedBase64 func FromSignedBase64(base64String, key string) (Map, error) { parts := strings.Split(base64String, SignatureSeparator) if len(parts) != 2 { return nil, errors.New("objx: Signed base64 string is malformed.") } sig := HashWithKey(parts[0], key) if parts[1] != sig { return nil, errors.New("objx: Signature for base64 data does not match.") } return FromBase64(parts[0]) } // MustFromSignedBase64 creates a new Obj containing the data specified // in the Base64 string and panics if there is an error. // // The string is an encoded JSON string returned by Base64 func MustFromSignedBase64(base64String, key string) Map { result, err := FromSignedBase64(base64String, key) if err != nil { panic("objx: MustFromSignedBase64 failed with error: " + err.Error()) } return result } // FromURLQuery generates a new Obj by parsing the specified // query. // // For queries with multiple values, the first value is selected. 
func FromURLQuery(query string) (Map, error) { vals, err := url.ParseQuery(query) if err != nil { return nil, err } m := make(map[string]interface{}) for k, vals := range vals { m[k] = vals[0] } return New(m), nil } // MustFromURLQuery generates a new Obj by parsing the specified // query. // // For queries with multiple values, the first value is selected. // // Panics if it encounters an error func MustFromURLQuery(query string) Map { o, err := FromURLQuery(query) if err != nil { panic("objx: MustFromURLQuery failed with error: " + err.Error()) } return o } ================================================ FILE: vendor/github.com/stretchr/objx/mutations.go ================================================ package objx // Exclude returns a new Map with the keys in the specified []string // excluded. func (d Map) Exclude(exclude []string) Map { excluded := make(Map) for k, v := range d { var shouldInclude bool = true for _, toExclude := range exclude { if k == toExclude { shouldInclude = false break } } if shouldInclude { excluded[k] = v } } return excluded } // Copy creates a shallow copy of the Obj. func (m Map) Copy() Map { copied := make(map[string]interface{}) for k, v := range m { copied[k] = v } return New(copied) } // Merge blends the specified map with a copy of this map and returns the result. // // Keys that appear in both will be selected from the specified map. // This method requires that the wrapped object be a map[string]interface{} func (m Map) Merge(merge Map) Map { return m.Copy().MergeHere(merge) } // Merge blends the specified map with this map and returns the current map. // // Keys that appear in both will be selected from the specified map. The original map // will be modified. This method requires that // the wrapped object be a map[string]interface{} func (m Map) MergeHere(merge Map) Map { for k, v := range merge { m[k] = v } return m } // Transform builds a new Obj giving the transformer a chance // to change the keys and values as it goes. 
This method requires that // the wrapped object be a map[string]interface{} func (m Map) Transform(transformer func(key string, value interface{}) (string, interface{})) Map { newMap := make(map[string]interface{}) for k, v := range m { modifiedKey, modifiedVal := transformer(k, v) newMap[modifiedKey] = modifiedVal } return New(newMap) } // TransformKeys builds a new map using the specified key mapping. // // Unspecified keys will be unaltered. // This method requires that the wrapped object be a map[string]interface{} func (m Map) TransformKeys(mapping map[string]string) Map { return m.Transform(func(key string, value interface{}) (string, interface{}) { if newKey, ok := mapping[key]; ok { return newKey, value } return key, value }) } ================================================ FILE: vendor/github.com/stretchr/objx/security.go ================================================ package objx import ( "crypto/sha1" "encoding/hex" ) // HashWithKey hashes the specified string using the security // key. func HashWithKey(data, key string) string { hash := sha1.New() hash.Write([]byte(data + ":" + key)) return hex.EncodeToString(hash.Sum(nil)) } ================================================ FILE: vendor/github.com/stretchr/objx/tests.go ================================================ package objx // Has gets whether there is something at the specified selector // or not. // // If m is nil, Has will always return false. func (m Map) Has(selector string) bool { if m == nil { return false } return !m.Get(selector).IsNil() } // IsNil gets whether the data is nil or not. 
func (v *Value) IsNil() bool { return v == nil || v.data == nil } ================================================ FILE: vendor/github.com/stretchr/objx/type_specific_codegen.go ================================================ package objx /* Inter (interface{} and []interface{}) -------------------------------------------------- */ // Inter gets the value as a interface{}, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Inter(optionalDefault ...interface{}) interface{} { if s, ok := v.data.(interface{}); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustInter gets the value as a interface{}. // // Panics if the object is not a interface{}. func (v *Value) MustInter() interface{} { return v.data.(interface{}) } // InterSlice gets the value as a []interface{}, returns the optionalDefault // value or nil if the value is not a []interface{}. func (v *Value) InterSlice(optionalDefault ...[]interface{}) []interface{} { if s, ok := v.data.([]interface{}); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustInterSlice gets the value as a []interface{}. // // Panics if the object is not a []interface{}. func (v *Value) MustInterSlice() []interface{} { return v.data.([]interface{}) } // IsInter gets whether the object contained is a interface{} or not. func (v *Value) IsInter() bool { _, ok := v.data.(interface{}) return ok } // IsInterSlice gets whether the object contained is a []interface{} or not. func (v *Value) IsInterSlice() bool { _, ok := v.data.([]interface{}) return ok } // EachInter calls the specified callback for each object // in the []interface{}. // // Panics if the object is the wrong type. 
func (v *Value) EachInter(callback func(int, interface{}) bool) *Value { for index, val := range v.MustInterSlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereInter uses the specified decider function to select items // from the []interface{}. The object contained in the result will contain // only the selected items. func (v *Value) WhereInter(decider func(int, interface{}) bool) *Value { var selected []interface{} v.EachInter(func(index int, val interface{}) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupInter uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]interface{}. func (v *Value) GroupInter(grouper func(int, interface{}) string) *Value { groups := make(map[string][]interface{}) v.EachInter(func(index int, val interface{}) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]interface{}, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceInter uses the specified function to replace each interface{}s // by iterating each item. The data in the returned result will be a // []interface{} containing the replaced items. func (v *Value) ReplaceInter(replacer func(int, interface{}) interface{}) *Value { arr := v.MustInterSlice() replaced := make([]interface{}, len(arr)) v.EachInter(func(index int, val interface{}) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectInter uses the specified collector function to collect a value // for each of the interface{}s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectInter(collector func(int, interface{}) interface{}) *Value { arr := v.MustInterSlice() collected := make([]interface{}, len(arr)) v.EachInter(func(index int, val interface{}) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* MSI (map[string]interface{} and []map[string]interface{}) -------------------------------------------------- */ // MSI gets the value as a map[string]interface{}, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) MSI(optionalDefault ...map[string]interface{}) map[string]interface{} { if s, ok := v.data.(map[string]interface{}); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustMSI gets the value as a map[string]interface{}. // // Panics if the object is not a map[string]interface{}. func (v *Value) MustMSI() map[string]interface{} { return v.data.(map[string]interface{}) } // MSISlice gets the value as a []map[string]interface{}, returns the optionalDefault // value or nil if the value is not a []map[string]interface{}. func (v *Value) MSISlice(optionalDefault ...[]map[string]interface{}) []map[string]interface{} { if s, ok := v.data.([]map[string]interface{}); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustMSISlice gets the value as a []map[string]interface{}. // // Panics if the object is not a []map[string]interface{}. func (v *Value) MustMSISlice() []map[string]interface{} { return v.data.([]map[string]interface{}) } // IsMSI gets whether the object contained is a map[string]interface{} or not. func (v *Value) IsMSI() bool { _, ok := v.data.(map[string]interface{}) return ok } // IsMSISlice gets whether the object contained is a []map[string]interface{} or not. 
func (v *Value) IsMSISlice() bool { _, ok := v.data.([]map[string]interface{}) return ok } // EachMSI calls the specified callback for each object // in the []map[string]interface{}. // // Panics if the object is the wrong type. func (v *Value) EachMSI(callback func(int, map[string]interface{}) bool) *Value { for index, val := range v.MustMSISlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereMSI uses the specified decider function to select items // from the []map[string]interface{}. The object contained in the result will contain // only the selected items. func (v *Value) WhereMSI(decider func(int, map[string]interface{}) bool) *Value { var selected []map[string]interface{} v.EachMSI(func(index int, val map[string]interface{}) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupMSI uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]map[string]interface{}. func (v *Value) GroupMSI(grouper func(int, map[string]interface{}) string) *Value { groups := make(map[string][]map[string]interface{}) v.EachMSI(func(index int, val map[string]interface{}) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]map[string]interface{}, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceMSI uses the specified function to replace each map[string]interface{}s // by iterating each item. The data in the returned result will be a // []map[string]interface{} containing the replaced items. 
func (v *Value) ReplaceMSI(replacer func(int, map[string]interface{}) map[string]interface{}) *Value { arr := v.MustMSISlice() replaced := make([]map[string]interface{}, len(arr)) v.EachMSI(func(index int, val map[string]interface{}) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectMSI uses the specified collector function to collect a value // for each of the map[string]interface{}s in the slice. The data returned will be a // []interface{}. func (v *Value) CollectMSI(collector func(int, map[string]interface{}) interface{}) *Value { arr := v.MustMSISlice() collected := make([]interface{}, len(arr)) v.EachMSI(func(index int, val map[string]interface{}) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* ObjxMap ((Map) and [](Map)) -------------------------------------------------- */ // ObjxMap gets the value as a (Map), returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) ObjxMap(optionalDefault ...(Map)) Map { if s, ok := v.data.((Map)); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return New(nil) } // MustObjxMap gets the value as a (Map). // // Panics if the object is not a (Map). func (v *Value) MustObjxMap() Map { return v.data.((Map)) } // ObjxMapSlice gets the value as a [](Map), returns the optionalDefault // value or nil if the value is not a [](Map). func (v *Value) ObjxMapSlice(optionalDefault ...[](Map)) [](Map) { if s, ok := v.data.([](Map)); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustObjxMapSlice gets the value as a [](Map). // // Panics if the object is not a [](Map). func (v *Value) MustObjxMapSlice() [](Map) { return v.data.([](Map)) } // IsObjxMap gets whether the object contained is a (Map) or not. 
func (v *Value) IsObjxMap() bool { _, ok := v.data.((Map)) return ok } // IsObjxMapSlice gets whether the object contained is a [](Map) or not. func (v *Value) IsObjxMapSlice() bool { _, ok := v.data.([](Map)) return ok } // EachObjxMap calls the specified callback for each object // in the [](Map). // // Panics if the object is the wrong type. func (v *Value) EachObjxMap(callback func(int, Map) bool) *Value { for index, val := range v.MustObjxMapSlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereObjxMap uses the specified decider function to select items // from the [](Map). The object contained in the result will contain // only the selected items. func (v *Value) WhereObjxMap(decider func(int, Map) bool) *Value { var selected [](Map) v.EachObjxMap(func(index int, val Map) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupObjxMap uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][](Map). func (v *Value) GroupObjxMap(grouper func(int, Map) string) *Value { groups := make(map[string][](Map)) v.EachObjxMap(func(index int, val Map) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([](Map), 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceObjxMap uses the specified function to replace each (Map)s // by iterating each item. The data in the returned result will be a // [](Map) containing the replaced items. 
func (v *Value) ReplaceObjxMap(replacer func(int, Map) Map) *Value { arr := v.MustObjxMapSlice() replaced := make([](Map), len(arr)) v.EachObjxMap(func(index int, val Map) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectObjxMap uses the specified collector function to collect a value // for each of the (Map)s in the slice. The data returned will be a // []interface{}. func (v *Value) CollectObjxMap(collector func(int, Map) interface{}) *Value { arr := v.MustObjxMapSlice() collected := make([]interface{}, len(arr)) v.EachObjxMap(func(index int, val Map) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Bool (bool and []bool) -------------------------------------------------- */ // Bool gets the value as a bool, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Bool(optionalDefault ...bool) bool { if s, ok := v.data.(bool); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return false } // MustBool gets the value as a bool. // // Panics if the object is not a bool. func (v *Value) MustBool() bool { return v.data.(bool) } // BoolSlice gets the value as a []bool, returns the optionalDefault // value or nil if the value is not a []bool. func (v *Value) BoolSlice(optionalDefault ...[]bool) []bool { if s, ok := v.data.([]bool); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustBoolSlice gets the value as a []bool. // // Panics if the object is not a []bool. func (v *Value) MustBoolSlice() []bool { return v.data.([]bool) } // IsBool gets whether the object contained is a bool or not. func (v *Value) IsBool() bool { _, ok := v.data.(bool) return ok } // IsBoolSlice gets whether the object contained is a []bool or not. 
func (v *Value) IsBoolSlice() bool { _, ok := v.data.([]bool) return ok } // EachBool calls the specified callback for each object // in the []bool. // // Panics if the object is the wrong type. func (v *Value) EachBool(callback func(int, bool) bool) *Value { for index, val := range v.MustBoolSlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereBool uses the specified decider function to select items // from the []bool. The object contained in the result will contain // only the selected items. func (v *Value) WhereBool(decider func(int, bool) bool) *Value { var selected []bool v.EachBool(func(index int, val bool) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupBool uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]bool. func (v *Value) GroupBool(grouper func(int, bool) string) *Value { groups := make(map[string][]bool) v.EachBool(func(index int, val bool) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]bool, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceBool uses the specified function to replace each bools // by iterating each item. The data in the returned result will be a // []bool containing the replaced items. func (v *Value) ReplaceBool(replacer func(int, bool) bool) *Value { arr := v.MustBoolSlice() replaced := make([]bool, len(arr)) v.EachBool(func(index int, val bool) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectBool uses the specified collector function to collect a value // for each of the bools in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectBool(collector func(int, bool) interface{}) *Value { arr := v.MustBoolSlice() collected := make([]interface{}, len(arr)) v.EachBool(func(index int, val bool) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Str (string and []string) -------------------------------------------------- */ // Str gets the value as a string, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Str(optionalDefault ...string) string { if s, ok := v.data.(string); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return "" } // MustStr gets the value as a string. // // Panics if the object is not a string. func (v *Value) MustStr() string { return v.data.(string) } // StrSlice gets the value as a []string, returns the optionalDefault // value or nil if the value is not a []string. func (v *Value) StrSlice(optionalDefault ...[]string) []string { if s, ok := v.data.([]string); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustStrSlice gets the value as a []string. // // Panics if the object is not a []string. func (v *Value) MustStrSlice() []string { return v.data.([]string) } // IsStr gets whether the object contained is a string or not. func (v *Value) IsStr() bool { _, ok := v.data.(string) return ok } // IsStrSlice gets whether the object contained is a []string or not. func (v *Value) IsStrSlice() bool { _, ok := v.data.([]string) return ok } // EachStr calls the specified callback for each object // in the []string. // // Panics if the object is the wrong type. func (v *Value) EachStr(callback func(int, string) bool) *Value { for index, val := range v.MustStrSlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereStr uses the specified decider function to select items // from the []string. 
The object contained in the result will contain // only the selected items. func (v *Value) WhereStr(decider func(int, string) bool) *Value { var selected []string v.EachStr(func(index int, val string) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupStr uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]string. func (v *Value) GroupStr(grouper func(int, string) string) *Value { groups := make(map[string][]string) v.EachStr(func(index int, val string) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]string, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceStr uses the specified function to replace each strings // by iterating each item. The data in the returned result will be a // []string containing the replaced items. func (v *Value) ReplaceStr(replacer func(int, string) string) *Value { arr := v.MustStrSlice() replaced := make([]string, len(arr)) v.EachStr(func(index int, val string) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectStr uses the specified collector function to collect a value // for each of the strings in the slice. The data returned will be a // []interface{}. func (v *Value) CollectStr(collector func(int, string) interface{}) *Value { arr := v.MustStrSlice() collected := make([]interface{}, len(arr)) v.EachStr(func(index int, val string) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Int (int and []int) -------------------------------------------------- */ // Int gets the value as a int, returns the optionalDefault // value or a system default object if the value is the wrong type. 
func (v *Value) Int(optionalDefault ...int) int { if s, ok := v.data.(int); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustInt gets the value as a int. // // Panics if the object is not a int. func (v *Value) MustInt() int { return v.data.(int) } // IntSlice gets the value as a []int, returns the optionalDefault // value or nil if the value is not a []int. func (v *Value) IntSlice(optionalDefault ...[]int) []int { if s, ok := v.data.([]int); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustIntSlice gets the value as a []int. // // Panics if the object is not a []int. func (v *Value) MustIntSlice() []int { return v.data.([]int) } // IsInt gets whether the object contained is a int or not. func (v *Value) IsInt() bool { _, ok := v.data.(int) return ok } // IsIntSlice gets whether the object contained is a []int or not. func (v *Value) IsIntSlice() bool { _, ok := v.data.([]int) return ok } // EachInt calls the specified callback for each object // in the []int. // // Panics if the object is the wrong type. func (v *Value) EachInt(callback func(int, int) bool) *Value { for index, val := range v.MustIntSlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereInt uses the specified decider function to select items // from the []int. The object contained in the result will contain // only the selected items. func (v *Value) WhereInt(decider func(int, int) bool) *Value { var selected []int v.EachInt(func(index int, val int) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupInt uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]int. 
func (v *Value) GroupInt(grouper func(int, int) string) *Value { groups := make(map[string][]int) v.EachInt(func(index int, val int) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]int, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceInt uses the specified function to replace each ints // by iterating each item. The data in the returned result will be a // []int containing the replaced items. func (v *Value) ReplaceInt(replacer func(int, int) int) *Value { arr := v.MustIntSlice() replaced := make([]int, len(arr)) v.EachInt(func(index int, val int) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectInt uses the specified collector function to collect a value // for each of the ints in the slice. The data returned will be a // []interface{}. func (v *Value) CollectInt(collector func(int, int) interface{}) *Value { arr := v.MustIntSlice() collected := make([]interface{}, len(arr)) v.EachInt(func(index int, val int) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Int8 (int8 and []int8) -------------------------------------------------- */ // Int8 gets the value as a int8, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Int8(optionalDefault ...int8) int8 { if s, ok := v.data.(int8); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustInt8 gets the value as a int8. // // Panics if the object is not a int8. func (v *Value) MustInt8() int8 { return v.data.(int8) } // Int8Slice gets the value as a []int8, returns the optionalDefault // value or nil if the value is not a []int8. 
func (v *Value) Int8Slice(optionalDefault ...[]int8) []int8 { if s, ok := v.data.([]int8); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustInt8Slice gets the value as a []int8. // // Panics if the object is not a []int8. func (v *Value) MustInt8Slice() []int8 { return v.data.([]int8) } // IsInt8 gets whether the object contained is a int8 or not. func (v *Value) IsInt8() bool { _, ok := v.data.(int8) return ok } // IsInt8Slice gets whether the object contained is a []int8 or not. func (v *Value) IsInt8Slice() bool { _, ok := v.data.([]int8) return ok } // EachInt8 calls the specified callback for each object // in the []int8. // // Panics if the object is the wrong type. func (v *Value) EachInt8(callback func(int, int8) bool) *Value { for index, val := range v.MustInt8Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereInt8 uses the specified decider function to select items // from the []int8. The object contained in the result will contain // only the selected items. func (v *Value) WhereInt8(decider func(int, int8) bool) *Value { var selected []int8 v.EachInt8(func(index int, val int8) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupInt8 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]int8. func (v *Value) GroupInt8(grouper func(int, int8) string) *Value { groups := make(map[string][]int8) v.EachInt8(func(index int, val int8) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]int8, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceInt8 uses the specified function to replace each int8s // by iterating each item. 
The data in the returned result will be a // []int8 containing the replaced items. func (v *Value) ReplaceInt8(replacer func(int, int8) int8) *Value { arr := v.MustInt8Slice() replaced := make([]int8, len(arr)) v.EachInt8(func(index int, val int8) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectInt8 uses the specified collector function to collect a value // for each of the int8s in the slice. The data returned will be a // []interface{}. func (v *Value) CollectInt8(collector func(int, int8) interface{}) *Value { arr := v.MustInt8Slice() collected := make([]interface{}, len(arr)) v.EachInt8(func(index int, val int8) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Int16 (int16 and []int16) -------------------------------------------------- */ // Int16 gets the value as a int16, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Int16(optionalDefault ...int16) int16 { if s, ok := v.data.(int16); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustInt16 gets the value as a int16. // // Panics if the object is not a int16. func (v *Value) MustInt16() int16 { return v.data.(int16) } // Int16Slice gets the value as a []int16, returns the optionalDefault // value or nil if the value is not a []int16. func (v *Value) Int16Slice(optionalDefault ...[]int16) []int16 { if s, ok := v.data.([]int16); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustInt16Slice gets the value as a []int16. // // Panics if the object is not a []int16. func (v *Value) MustInt16Slice() []int16 { return v.data.([]int16) } // IsInt16 gets whether the object contained is a int16 or not. func (v *Value) IsInt16() bool { _, ok := v.data.(int16) return ok } // IsInt16Slice gets whether the object contained is a []int16 or not. 
func (v *Value) IsInt16Slice() bool { _, ok := v.data.([]int16) return ok } // EachInt16 calls the specified callback for each object // in the []int16. // // Panics if the object is the wrong type. func (v *Value) EachInt16(callback func(int, int16) bool) *Value { for index, val := range v.MustInt16Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereInt16 uses the specified decider function to select items // from the []int16. The object contained in the result will contain // only the selected items. func (v *Value) WhereInt16(decider func(int, int16) bool) *Value { var selected []int16 v.EachInt16(func(index int, val int16) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupInt16 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]int16. func (v *Value) GroupInt16(grouper func(int, int16) string) *Value { groups := make(map[string][]int16) v.EachInt16(func(index int, val int16) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]int16, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceInt16 uses the specified function to replace each int16s // by iterating each item. The data in the returned result will be a // []int16 containing the replaced items. func (v *Value) ReplaceInt16(replacer func(int, int16) int16) *Value { arr := v.MustInt16Slice() replaced := make([]int16, len(arr)) v.EachInt16(func(index int, val int16) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectInt16 uses the specified collector function to collect a value // for each of the int16s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectInt16(collector func(int, int16) interface{}) *Value { arr := v.MustInt16Slice() collected := make([]interface{}, len(arr)) v.EachInt16(func(index int, val int16) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Int32 (int32 and []int32) -------------------------------------------------- */ // Int32 gets the value as a int32, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Int32(optionalDefault ...int32) int32 { if s, ok := v.data.(int32); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustInt32 gets the value as a int32. // // Panics if the object is not a int32. func (v *Value) MustInt32() int32 { return v.data.(int32) } // Int32Slice gets the value as a []int32, returns the optionalDefault // value or nil if the value is not a []int32. func (v *Value) Int32Slice(optionalDefault ...[]int32) []int32 { if s, ok := v.data.([]int32); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustInt32Slice gets the value as a []int32. // // Panics if the object is not a []int32. func (v *Value) MustInt32Slice() []int32 { return v.data.([]int32) } // IsInt32 gets whether the object contained is a int32 or not. func (v *Value) IsInt32() bool { _, ok := v.data.(int32) return ok } // IsInt32Slice gets whether the object contained is a []int32 or not. func (v *Value) IsInt32Slice() bool { _, ok := v.data.([]int32) return ok } // EachInt32 calls the specified callback for each object // in the []int32. // // Panics if the object is the wrong type. func (v *Value) EachInt32(callback func(int, int32) bool) *Value { for index, val := range v.MustInt32Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereInt32 uses the specified decider function to select items // from the []int32. 
The object contained in the result will contain // only the selected items. func (v *Value) WhereInt32(decider func(int, int32) bool) *Value { var selected []int32 v.EachInt32(func(index int, val int32) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupInt32 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]int32. func (v *Value) GroupInt32(grouper func(int, int32) string) *Value { groups := make(map[string][]int32) v.EachInt32(func(index int, val int32) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]int32, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceInt32 uses the specified function to replace each int32s // by iterating each item. The data in the returned result will be a // []int32 containing the replaced items. func (v *Value) ReplaceInt32(replacer func(int, int32) int32) *Value { arr := v.MustInt32Slice() replaced := make([]int32, len(arr)) v.EachInt32(func(index int, val int32) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectInt32 uses the specified collector function to collect a value // for each of the int32s in the slice. The data returned will be a // []interface{}. func (v *Value) CollectInt32(collector func(int, int32) interface{}) *Value { arr := v.MustInt32Slice() collected := make([]interface{}, len(arr)) v.EachInt32(func(index int, val int32) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Int64 (int64 and []int64) -------------------------------------------------- */ // Int64 gets the value as a int64, returns the optionalDefault // value or a system default object if the value is the wrong type. 
func (v *Value) Int64(optionalDefault ...int64) int64 { if s, ok := v.data.(int64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustInt64 gets the value as a int64. // // Panics if the object is not a int64. func (v *Value) MustInt64() int64 { return v.data.(int64) } // Int64Slice gets the value as a []int64, returns the optionalDefault // value or nil if the value is not a []int64. func (v *Value) Int64Slice(optionalDefault ...[]int64) []int64 { if s, ok := v.data.([]int64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustInt64Slice gets the value as a []int64. // // Panics if the object is not a []int64. func (v *Value) MustInt64Slice() []int64 { return v.data.([]int64) } // IsInt64 gets whether the object contained is a int64 or not. func (v *Value) IsInt64() bool { _, ok := v.data.(int64) return ok } // IsInt64Slice gets whether the object contained is a []int64 or not. func (v *Value) IsInt64Slice() bool { _, ok := v.data.([]int64) return ok } // EachInt64 calls the specified callback for each object // in the []int64. // // Panics if the object is the wrong type. func (v *Value) EachInt64(callback func(int, int64) bool) *Value { for index, val := range v.MustInt64Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereInt64 uses the specified decider function to select items // from the []int64. The object contained in the result will contain // only the selected items. func (v *Value) WhereInt64(decider func(int, int64) bool) *Value { var selected []int64 v.EachInt64(func(index int, val int64) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupInt64 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]int64. 
func (v *Value) GroupInt64(grouper func(int, int64) string) *Value { groups := make(map[string][]int64) v.EachInt64(func(index int, val int64) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]int64, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceInt64 uses the specified function to replace each int64s // by iterating each item. The data in the returned result will be a // []int64 containing the replaced items. func (v *Value) ReplaceInt64(replacer func(int, int64) int64) *Value { arr := v.MustInt64Slice() replaced := make([]int64, len(arr)) v.EachInt64(func(index int, val int64) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectInt64 uses the specified collector function to collect a value // for each of the int64s in the slice. The data returned will be a // []interface{}. func (v *Value) CollectInt64(collector func(int, int64) interface{}) *Value { arr := v.MustInt64Slice() collected := make([]interface{}, len(arr)) v.EachInt64(func(index int, val int64) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Uint (uint and []uint) -------------------------------------------------- */ // Uint gets the value as a uint, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Uint(optionalDefault ...uint) uint { if s, ok := v.data.(uint); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustUint gets the value as a uint. // // Panics if the object is not a uint. func (v *Value) MustUint() uint { return v.data.(uint) } // UintSlice gets the value as a []uint, returns the optionalDefault // value or nil if the value is not a []uint. 
func (v *Value) UintSlice(optionalDefault ...[]uint) []uint { if s, ok := v.data.([]uint); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustUintSlice gets the value as a []uint. // // Panics if the object is not a []uint. func (v *Value) MustUintSlice() []uint { return v.data.([]uint) } // IsUint gets whether the object contained is a uint or not. func (v *Value) IsUint() bool { _, ok := v.data.(uint) return ok } // IsUintSlice gets whether the object contained is a []uint or not. func (v *Value) IsUintSlice() bool { _, ok := v.data.([]uint) return ok } // EachUint calls the specified callback for each object // in the []uint. // // Panics if the object is the wrong type. func (v *Value) EachUint(callback func(int, uint) bool) *Value { for index, val := range v.MustUintSlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereUint uses the specified decider function to select items // from the []uint. The object contained in the result will contain // only the selected items. func (v *Value) WhereUint(decider func(int, uint) bool) *Value { var selected []uint v.EachUint(func(index int, val uint) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupUint uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]uint. func (v *Value) GroupUint(grouper func(int, uint) string) *Value { groups := make(map[string][]uint) v.EachUint(func(index int, val uint) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]uint, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceUint uses the specified function to replace each uints // by iterating each item. 
The data in the returned result will be a // []uint containing the replaced items. func (v *Value) ReplaceUint(replacer func(int, uint) uint) *Value { arr := v.MustUintSlice() replaced := make([]uint, len(arr)) v.EachUint(func(index int, val uint) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectUint uses the specified collector function to collect a value // for each of the uints in the slice. The data returned will be a // []interface{}. func (v *Value) CollectUint(collector func(int, uint) interface{}) *Value { arr := v.MustUintSlice() collected := make([]interface{}, len(arr)) v.EachUint(func(index int, val uint) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Uint8 (uint8 and []uint8) -------------------------------------------------- */ // Uint8 gets the value as a uint8, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Uint8(optionalDefault ...uint8) uint8 { if s, ok := v.data.(uint8); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustUint8 gets the value as a uint8. // // Panics if the object is not a uint8. func (v *Value) MustUint8() uint8 { return v.data.(uint8) } // Uint8Slice gets the value as a []uint8, returns the optionalDefault // value or nil if the value is not a []uint8. func (v *Value) Uint8Slice(optionalDefault ...[]uint8) []uint8 { if s, ok := v.data.([]uint8); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustUint8Slice gets the value as a []uint8. // // Panics if the object is not a []uint8. func (v *Value) MustUint8Slice() []uint8 { return v.data.([]uint8) } // IsUint8 gets whether the object contained is a uint8 or not. func (v *Value) IsUint8() bool { _, ok := v.data.(uint8) return ok } // IsUint8Slice gets whether the object contained is a []uint8 or not. 
func (v *Value) IsUint8Slice() bool { _, ok := v.data.([]uint8) return ok } // EachUint8 calls the specified callback for each object // in the []uint8. // // Panics if the object is the wrong type. func (v *Value) EachUint8(callback func(int, uint8) bool) *Value { for index, val := range v.MustUint8Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereUint8 uses the specified decider function to select items // from the []uint8. The object contained in the result will contain // only the selected items. func (v *Value) WhereUint8(decider func(int, uint8) bool) *Value { var selected []uint8 v.EachUint8(func(index int, val uint8) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupUint8 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]uint8. func (v *Value) GroupUint8(grouper func(int, uint8) string) *Value { groups := make(map[string][]uint8) v.EachUint8(func(index int, val uint8) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]uint8, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceUint8 uses the specified function to replace each uint8s // by iterating each item. The data in the returned result will be a // []uint8 containing the replaced items. func (v *Value) ReplaceUint8(replacer func(int, uint8) uint8) *Value { arr := v.MustUint8Slice() replaced := make([]uint8, len(arr)) v.EachUint8(func(index int, val uint8) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectUint8 uses the specified collector function to collect a value // for each of the uint8s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectUint8(collector func(int, uint8) interface{}) *Value { arr := v.MustUint8Slice() collected := make([]interface{}, len(arr)) v.EachUint8(func(index int, val uint8) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Uint16 (uint16 and []uint16) -------------------------------------------------- */ // Uint16 gets the value as a uint16, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Uint16(optionalDefault ...uint16) uint16 { if s, ok := v.data.(uint16); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustUint16 gets the value as a uint16. // // Panics if the object is not a uint16. func (v *Value) MustUint16() uint16 { return v.data.(uint16) } // Uint16Slice gets the value as a []uint16, returns the optionalDefault // value or nil if the value is not a []uint16. func (v *Value) Uint16Slice(optionalDefault ...[]uint16) []uint16 { if s, ok := v.data.([]uint16); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustUint16Slice gets the value as a []uint16. // // Panics if the object is not a []uint16. func (v *Value) MustUint16Slice() []uint16 { return v.data.([]uint16) } // IsUint16 gets whether the object contained is a uint16 or not. func (v *Value) IsUint16() bool { _, ok := v.data.(uint16) return ok } // IsUint16Slice gets whether the object contained is a []uint16 or not. func (v *Value) IsUint16Slice() bool { _, ok := v.data.([]uint16) return ok } // EachUint16 calls the specified callback for each object // in the []uint16. // // Panics if the object is the wrong type. func (v *Value) EachUint16(callback func(int, uint16) bool) *Value { for index, val := range v.MustUint16Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereUint16 uses the specified decider function to select items // from the []uint16. 
The object contained in the result will contain // only the selected items. func (v *Value) WhereUint16(decider func(int, uint16) bool) *Value { var selected []uint16 v.EachUint16(func(index int, val uint16) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupUint16 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]uint16. func (v *Value) GroupUint16(grouper func(int, uint16) string) *Value { groups := make(map[string][]uint16) v.EachUint16(func(index int, val uint16) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]uint16, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceUint16 uses the specified function to replace each uint16s // by iterating each item. The data in the returned result will be a // []uint16 containing the replaced items. func (v *Value) ReplaceUint16(replacer func(int, uint16) uint16) *Value { arr := v.MustUint16Slice() replaced := make([]uint16, len(arr)) v.EachUint16(func(index int, val uint16) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectUint16 uses the specified collector function to collect a value // for each of the uint16s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectUint16(collector func(int, uint16) interface{}) *Value { arr := v.MustUint16Slice() collected := make([]interface{}, len(arr)) v.EachUint16(func(index int, val uint16) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Uint32 (uint32 and []uint32) -------------------------------------------------- */ // Uint32 gets the value as a uint32, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Uint32(optionalDefault ...uint32) uint32 { if s, ok := v.data.(uint32); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustUint32 gets the value as a uint32. // // Panics if the object is not a uint32. func (v *Value) MustUint32() uint32 { return v.data.(uint32) } // Uint32Slice gets the value as a []uint32, returns the optionalDefault // value or nil if the value is not a []uint32. func (v *Value) Uint32Slice(optionalDefault ...[]uint32) []uint32 { if s, ok := v.data.([]uint32); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustUint32Slice gets the value as a []uint32. // // Panics if the object is not a []uint32. func (v *Value) MustUint32Slice() []uint32 { return v.data.([]uint32) } // IsUint32 gets whether the object contained is a uint32 or not. func (v *Value) IsUint32() bool { _, ok := v.data.(uint32) return ok } // IsUint32Slice gets whether the object contained is a []uint32 or not. func (v *Value) IsUint32Slice() bool { _, ok := v.data.([]uint32) return ok } // EachUint32 calls the specified callback for each object // in the []uint32. // // Panics if the object is the wrong type. func (v *Value) EachUint32(callback func(int, uint32) bool) *Value { for index, val := range v.MustUint32Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereUint32 uses the specified decider function to select items // from the []uint32. 
The object contained in the result will contain // only the selected items. func (v *Value) WhereUint32(decider func(int, uint32) bool) *Value { var selected []uint32 v.EachUint32(func(index int, val uint32) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupUint32 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]uint32. func (v *Value) GroupUint32(grouper func(int, uint32) string) *Value { groups := make(map[string][]uint32) v.EachUint32(func(index int, val uint32) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]uint32, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceUint32 uses the specified function to replace each uint32s // by iterating each item. The data in the returned result will be a // []uint32 containing the replaced items. func (v *Value) ReplaceUint32(replacer func(int, uint32) uint32) *Value { arr := v.MustUint32Slice() replaced := make([]uint32, len(arr)) v.EachUint32(func(index int, val uint32) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectUint32 uses the specified collector function to collect a value // for each of the uint32s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectUint32(collector func(int, uint32) interface{}) *Value { arr := v.MustUint32Slice() collected := make([]interface{}, len(arr)) v.EachUint32(func(index int, val uint32) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Uint64 (uint64 and []uint64) -------------------------------------------------- */ // Uint64 gets the value as a uint64, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Uint64(optionalDefault ...uint64) uint64 { if s, ok := v.data.(uint64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustUint64 gets the value as a uint64. // // Panics if the object is not a uint64. func (v *Value) MustUint64() uint64 { return v.data.(uint64) } // Uint64Slice gets the value as a []uint64, returns the optionalDefault // value or nil if the value is not a []uint64. func (v *Value) Uint64Slice(optionalDefault ...[]uint64) []uint64 { if s, ok := v.data.([]uint64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustUint64Slice gets the value as a []uint64. // // Panics if the object is not a []uint64. func (v *Value) MustUint64Slice() []uint64 { return v.data.([]uint64) } // IsUint64 gets whether the object contained is a uint64 or not. func (v *Value) IsUint64() bool { _, ok := v.data.(uint64) return ok } // IsUint64Slice gets whether the object contained is a []uint64 or not. func (v *Value) IsUint64Slice() bool { _, ok := v.data.([]uint64) return ok } // EachUint64 calls the specified callback for each object // in the []uint64. // // Panics if the object is the wrong type. func (v *Value) EachUint64(callback func(int, uint64) bool) *Value { for index, val := range v.MustUint64Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereUint64 uses the specified decider function to select items // from the []uint64. 
The object contained in the result will contain // only the selected items. func (v *Value) WhereUint64(decider func(int, uint64) bool) *Value { var selected []uint64 v.EachUint64(func(index int, val uint64) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupUint64 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]uint64. func (v *Value) GroupUint64(grouper func(int, uint64) string) *Value { groups := make(map[string][]uint64) v.EachUint64(func(index int, val uint64) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]uint64, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceUint64 uses the specified function to replace each uint64s // by iterating each item. The data in the returned result will be a // []uint64 containing the replaced items. func (v *Value) ReplaceUint64(replacer func(int, uint64) uint64) *Value { arr := v.MustUint64Slice() replaced := make([]uint64, len(arr)) v.EachUint64(func(index int, val uint64) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectUint64 uses the specified collector function to collect a value // for each of the uint64s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectUint64(collector func(int, uint64) interface{}) *Value { arr := v.MustUint64Slice() collected := make([]interface{}, len(arr)) v.EachUint64(func(index int, val uint64) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Uintptr (uintptr and []uintptr) -------------------------------------------------- */ // Uintptr gets the value as a uintptr, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Uintptr(optionalDefault ...uintptr) uintptr { if s, ok := v.data.(uintptr); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustUintptr gets the value as a uintptr. // // Panics if the object is not a uintptr. func (v *Value) MustUintptr() uintptr { return v.data.(uintptr) } // UintptrSlice gets the value as a []uintptr, returns the optionalDefault // value or nil if the value is not a []uintptr. func (v *Value) UintptrSlice(optionalDefault ...[]uintptr) []uintptr { if s, ok := v.data.([]uintptr); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustUintptrSlice gets the value as a []uintptr. // // Panics if the object is not a []uintptr. func (v *Value) MustUintptrSlice() []uintptr { return v.data.([]uintptr) } // IsUintptr gets whether the object contained is a uintptr or not. func (v *Value) IsUintptr() bool { _, ok := v.data.(uintptr) return ok } // IsUintptrSlice gets whether the object contained is a []uintptr or not. func (v *Value) IsUintptrSlice() bool { _, ok := v.data.([]uintptr) return ok } // EachUintptr calls the specified callback for each object // in the []uintptr. // // Panics if the object is the wrong type. 
func (v *Value) EachUintptr(callback func(int, uintptr) bool) *Value { for index, val := range v.MustUintptrSlice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereUintptr uses the specified decider function to select items // from the []uintptr. The object contained in the result will contain // only the selected items. func (v *Value) WhereUintptr(decider func(int, uintptr) bool) *Value { var selected []uintptr v.EachUintptr(func(index int, val uintptr) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupUintptr uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]uintptr. func (v *Value) GroupUintptr(grouper func(int, uintptr) string) *Value { groups := make(map[string][]uintptr) v.EachUintptr(func(index int, val uintptr) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]uintptr, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceUintptr uses the specified function to replace each uintptrs // by iterating each item. The data in the returned result will be a // []uintptr containing the replaced items. func (v *Value) ReplaceUintptr(replacer func(int, uintptr) uintptr) *Value { arr := v.MustUintptrSlice() replaced := make([]uintptr, len(arr)) v.EachUintptr(func(index int, val uintptr) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectUintptr uses the specified collector function to collect a value // for each of the uintptrs in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectUintptr(collector func(int, uintptr) interface{}) *Value { arr := v.MustUintptrSlice() collected := make([]interface{}, len(arr)) v.EachUintptr(func(index int, val uintptr) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Float32 (float32 and []float32) -------------------------------------------------- */ // Float32 gets the value as a float32, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Float32(optionalDefault ...float32) float32 { if s, ok := v.data.(float32); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustFloat32 gets the value as a float32. // // Panics if the object is not a float32. func (v *Value) MustFloat32() float32 { return v.data.(float32) } // Float32Slice gets the value as a []float32, returns the optionalDefault // value or nil if the value is not a []float32. func (v *Value) Float32Slice(optionalDefault ...[]float32) []float32 { if s, ok := v.data.([]float32); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustFloat32Slice gets the value as a []float32. // // Panics if the object is not a []float32. func (v *Value) MustFloat32Slice() []float32 { return v.data.([]float32) } // IsFloat32 gets whether the object contained is a float32 or not. func (v *Value) IsFloat32() bool { _, ok := v.data.(float32) return ok } // IsFloat32Slice gets whether the object contained is a []float32 or not. func (v *Value) IsFloat32Slice() bool { _, ok := v.data.([]float32) return ok } // EachFloat32 calls the specified callback for each object // in the []float32. // // Panics if the object is the wrong type. 
func (v *Value) EachFloat32(callback func(int, float32) bool) *Value { for index, val := range v.MustFloat32Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereFloat32 uses the specified decider function to select items // from the []float32. The object contained in the result will contain // only the selected items. func (v *Value) WhereFloat32(decider func(int, float32) bool) *Value { var selected []float32 v.EachFloat32(func(index int, val float32) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupFloat32 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]float32. func (v *Value) GroupFloat32(grouper func(int, float32) string) *Value { groups := make(map[string][]float32) v.EachFloat32(func(index int, val float32) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]float32, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceFloat32 uses the specified function to replace each float32s // by iterating each item. The data in the returned result will be a // []float32 containing the replaced items. func (v *Value) ReplaceFloat32(replacer func(int, float32) float32) *Value { arr := v.MustFloat32Slice() replaced := make([]float32, len(arr)) v.EachFloat32(func(index int, val float32) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectFloat32 uses the specified collector function to collect a value // for each of the float32s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectFloat32(collector func(int, float32) interface{}) *Value { arr := v.MustFloat32Slice() collected := make([]interface{}, len(arr)) v.EachFloat32(func(index int, val float32) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Float64 (float64 and []float64) -------------------------------------------------- */ // Float64 gets the value as a float64, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Float64(optionalDefault ...float64) float64 { if s, ok := v.data.(float64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustFloat64 gets the value as a float64. // // Panics if the object is not a float64. func (v *Value) MustFloat64() float64 { return v.data.(float64) } // Float64Slice gets the value as a []float64, returns the optionalDefault // value or nil if the value is not a []float64. func (v *Value) Float64Slice(optionalDefault ...[]float64) []float64 { if s, ok := v.data.([]float64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustFloat64Slice gets the value as a []float64. // // Panics if the object is not a []float64. func (v *Value) MustFloat64Slice() []float64 { return v.data.([]float64) } // IsFloat64 gets whether the object contained is a float64 or not. func (v *Value) IsFloat64() bool { _, ok := v.data.(float64) return ok } // IsFloat64Slice gets whether the object contained is a []float64 or not. func (v *Value) IsFloat64Slice() bool { _, ok := v.data.([]float64) return ok } // EachFloat64 calls the specified callback for each object // in the []float64. // // Panics if the object is the wrong type. 
func (v *Value) EachFloat64(callback func(int, float64) bool) *Value { for index, val := range v.MustFloat64Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereFloat64 uses the specified decider function to select items // from the []float64. The object contained in the result will contain // only the selected items. func (v *Value) WhereFloat64(decider func(int, float64) bool) *Value { var selected []float64 v.EachFloat64(func(index int, val float64) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupFloat64 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]float64. func (v *Value) GroupFloat64(grouper func(int, float64) string) *Value { groups := make(map[string][]float64) v.EachFloat64(func(index int, val float64) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]float64, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceFloat64 uses the specified function to replace each float64s // by iterating each item. The data in the returned result will be a // []float64 containing the replaced items. func (v *Value) ReplaceFloat64(replacer func(int, float64) float64) *Value { arr := v.MustFloat64Slice() replaced := make([]float64, len(arr)) v.EachFloat64(func(index int, val float64) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectFloat64 uses the specified collector function to collect a value // for each of the float64s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectFloat64(collector func(int, float64) interface{}) *Value { arr := v.MustFloat64Slice() collected := make([]interface{}, len(arr)) v.EachFloat64(func(index int, val float64) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Complex64 (complex64 and []complex64) -------------------------------------------------- */ // Complex64 gets the value as a complex64, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Complex64(optionalDefault ...complex64) complex64 { if s, ok := v.data.(complex64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustComplex64 gets the value as a complex64. // // Panics if the object is not a complex64. func (v *Value) MustComplex64() complex64 { return v.data.(complex64) } // Complex64Slice gets the value as a []complex64, returns the optionalDefault // value or nil if the value is not a []complex64. func (v *Value) Complex64Slice(optionalDefault ...[]complex64) []complex64 { if s, ok := v.data.([]complex64); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustComplex64Slice gets the value as a []complex64. // // Panics if the object is not a []complex64. func (v *Value) MustComplex64Slice() []complex64 { return v.data.([]complex64) } // IsComplex64 gets whether the object contained is a complex64 or not. func (v *Value) IsComplex64() bool { _, ok := v.data.(complex64) return ok } // IsComplex64Slice gets whether the object contained is a []complex64 or not. func (v *Value) IsComplex64Slice() bool { _, ok := v.data.([]complex64) return ok } // EachComplex64 calls the specified callback for each object // in the []complex64. // // Panics if the object is the wrong type. 
func (v *Value) EachComplex64(callback func(int, complex64) bool) *Value { for index, val := range v.MustComplex64Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereComplex64 uses the specified decider function to select items // from the []complex64. The object contained in the result will contain // only the selected items. func (v *Value) WhereComplex64(decider func(int, complex64) bool) *Value { var selected []complex64 v.EachComplex64(func(index int, val complex64) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupComplex64 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]complex64. func (v *Value) GroupComplex64(grouper func(int, complex64) string) *Value { groups := make(map[string][]complex64) v.EachComplex64(func(index int, val complex64) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]complex64, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceComplex64 uses the specified function to replace each complex64s // by iterating each item. The data in the returned result will be a // []complex64 containing the replaced items. func (v *Value) ReplaceComplex64(replacer func(int, complex64) complex64) *Value { arr := v.MustComplex64Slice() replaced := make([]complex64, len(arr)) v.EachComplex64(func(index int, val complex64) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectComplex64 uses the specified collector function to collect a value // for each of the complex64s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectComplex64(collector func(int, complex64) interface{}) *Value { arr := v.MustComplex64Slice() collected := make([]interface{}, len(arr)) v.EachComplex64(func(index int, val complex64) bool { collected[index] = collector(index, val) return true }) return &Value{data: collected} } /* Complex128 (complex128 and []complex128) -------------------------------------------------- */ // Complex128 gets the value as a complex128, returns the optionalDefault // value or a system default object if the value is the wrong type. func (v *Value) Complex128(optionalDefault ...complex128) complex128 { if s, ok := v.data.(complex128); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return 0 } // MustComplex128 gets the value as a complex128. // // Panics if the object is not a complex128. func (v *Value) MustComplex128() complex128 { return v.data.(complex128) } // Complex128Slice gets the value as a []complex128, returns the optionalDefault // value or nil if the value is not a []complex128. func (v *Value) Complex128Slice(optionalDefault ...[]complex128) []complex128 { if s, ok := v.data.([]complex128); ok { return s } if len(optionalDefault) == 1 { return optionalDefault[0] } return nil } // MustComplex128Slice gets the value as a []complex128. // // Panics if the object is not a []complex128. func (v *Value) MustComplex128Slice() []complex128 { return v.data.([]complex128) } // IsComplex128 gets whether the object contained is a complex128 or not. func (v *Value) IsComplex128() bool { _, ok := v.data.(complex128) return ok } // IsComplex128Slice gets whether the object contained is a []complex128 or not. func (v *Value) IsComplex128Slice() bool { _, ok := v.data.([]complex128) return ok } // EachComplex128 calls the specified callback for each object // in the []complex128. // // Panics if the object is the wrong type. 
func (v *Value) EachComplex128(callback func(int, complex128) bool) *Value { for index, val := range v.MustComplex128Slice() { carryon := callback(index, val) if carryon == false { break } } return v } // WhereComplex128 uses the specified decider function to select items // from the []complex128. The object contained in the result will contain // only the selected items. func (v *Value) WhereComplex128(decider func(int, complex128) bool) *Value { var selected []complex128 v.EachComplex128(func(index int, val complex128) bool { shouldSelect := decider(index, val) if shouldSelect == false { selected = append(selected, val) } return true }) return &Value{data: selected} } // GroupComplex128 uses the specified grouper function to group the items // keyed by the return of the grouper. The object contained in the // result will contain a map[string][]complex128. func (v *Value) GroupComplex128(grouper func(int, complex128) string) *Value { groups := make(map[string][]complex128) v.EachComplex128(func(index int, val complex128) bool { group := grouper(index, val) if _, ok := groups[group]; !ok { groups[group] = make([]complex128, 0) } groups[group] = append(groups[group], val) return true }) return &Value{data: groups} } // ReplaceComplex128 uses the specified function to replace each complex128s // by iterating each item. The data in the returned result will be a // []complex128 containing the replaced items. func (v *Value) ReplaceComplex128(replacer func(int, complex128) complex128) *Value { arr := v.MustComplex128Slice() replaced := make([]complex128, len(arr)) v.EachComplex128(func(index int, val complex128) bool { replaced[index] = replacer(index, val) return true }) return &Value{data: replaced} } // CollectComplex128 uses the specified collector function to collect a value // for each of the complex128s in the slice. The data returned will be a // []interface{}. 
func (v *Value) CollectComplex128(collector func(int, complex128) interface{}) *Value {
	// MustComplex128Slice panics here, before the collector is ever called,
	// if the data is not a []complex128; its length sizes the result.
	arr := v.MustComplex128Slice()
	collected := make([]interface{}, len(arr))
	// The callback always returns true, so every element is visited and
	// slot i receives the collector's result for element i.
	v.EachComplex128(func(index int, val complex128) bool {
		collected[index] = collector(index, val)
		return true
	})
	return &Value{data: collected}
}

================================================
FILE: vendor/github.com/stretchr/objx/value.go
================================================
package objx

// Value provides methods for extracting interface{} data in various
// types. It wraps a single arbitrary value; the typed accessors perform
// type assertions on it.
type Value struct {
	// data contains the raw data being managed by this Value
	data interface{}
}

// Data returns the raw data contained by this Value, exactly as stored
// (no conversion is applied).
func (v *Value) Data() interface{} {
	return v.data
}

================================================
FILE: vendor/github.com/stretchr/testify/LICENSE
================================================
Copyright (c) 2012 - 2013 Mat Ryer and Tyler Bunnell

Please consider promoting this project if you find it useful.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
================================================
FILE: vendor/github.com/stretchr/testify/assert/assertion_format.go
================================================
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
 */

// NOTE(review): every "…f" function in this file forwards to the
// same-named assertion without the suffix, passing msg and args as one
// msgAndArgs list with msg first.

package assert

import (
	http "net/http"
	url "net/url"
	time "time"
)

// Conditionf uses a Comparison to assert a complex condition.
func Conditionf(t TestingT, comp Comparison, msg string, args ...interface{}) bool {
	// Prepend msg so Condition receives (msg, args...) as its msgAndArgs.
	return Condition(t, comp, append([]interface{}{msg}, args...)...)
}

// Containsf asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
//	assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
//	assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
//	assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
	return Contains(t, s, contains, append([]interface{}{msg}, args...)...)
}

// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
//	assert.Emptyf(t, obj, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
	return Empty(t, object, append([]interface{}{msg}, args...)...)
}

// Equalf asserts that two objects are equal.
// // assert.Equalf(t, 123, 123, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality // cannot be determined and will always fail. func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { return Equal(t, expected, actual, append([]interface{}{msg}, args...)...) } // EqualErrorf asserts that a function returned an error (i.e. not `nil`) // and that it is equal to the provided error. // // actualObj, err := SomeFunction() // assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool { return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...) } // EqualValuesf asserts that two objects are equal or convertable to the same types // and equal. // // assert.EqualValuesf(t, uint32(123, "error message %s", "formatted"), int32(123)) // // Returns whether the assertion was successful (true) or not (false). func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) } // Errorf asserts that a function returned an error (i.e. not `nil`). // // actualObj, err := SomeFunction() // if assert.Errorf(t, err, "error message %s", "formatted") { // assert.Equal(t, expectedErrorf, err) // } // // Returns whether the assertion was successful (true) or not (false). func Errorf(t TestingT, err error, msg string, args ...interface{}) bool { return Error(t, err, append([]interface{}{msg}, args...)...) } // Exactlyf asserts that two objects are equal is value and type. 
// // assert.Exactlyf(t, int32(123, "error message %s", "formatted"), int64(123)) // // Returns whether the assertion was successful (true) or not (false). func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { return Exactly(t, expected, actual, append([]interface{}{msg}, args...)...) } // Failf reports a failure through func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { return Fail(t, failureMessage, append([]interface{}{msg}, args...)...) } // FailNowf fails test func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...) } // Falsef asserts that the specified value is false. // // assert.Falsef(t, myBool, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool { return False(t, value, append([]interface{}{msg}, args...)...) } // HTTPBodyContainsf asserts that a specified handler returns a // body that contains a string. // // assert.HTTPBodyContainsf(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool { return HTTPBodyContains(t, handler, method, url, values, str) } // HTTPBodyNotContainsf asserts that a specified handler returns a // body that does not contain a string. // // assert.HTTPBodyNotContainsf(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool {
	absent := HTTPBodyNotContains(t, handler, method, url, values, str)
	return absent
}

// HTTPErrorf asserts that a specified handler returns an error status code.
// It takes no message/format arguments and forwards to HTTPError.
//
//	assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}})
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) bool {
	isError := HTTPError(t, handler, method, url, values)
	return isError
}

// HTTPRedirectf asserts that a specified handler returns a redirect status code.
// It takes no message/format arguments and forwards to HTTPRedirect.
//
//	assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}})
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) bool {
	isRedirect := HTTPRedirect(t, handler, method, url, values)
	return isRedirect
}

// HTTPSuccessf asserts that a specified handler returns a success status code.
// It takes no message/format arguments and forwards to HTTPSuccess.
//
//	assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil)
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) bool {
	isSuccess := HTTPSuccess(t, handler, method, url, values)
	return isSuccess
}

// Implementsf asserts that an object is implemented by the specified interface.
//
//	assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return Implements(t, interfaceObject, object, msgAndArgs...)
}

// InDeltaf asserts that the two numerals are within delta of each other.
// // assert.InDeltaf(t, math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01) // // Returns whether the assertion was successful (true) or not (false). func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...) } // InDeltaSlicef is the same as InDelta, except it compares two slices. func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { return InDeltaSlice(t, expected, actual, delta, append([]interface{}{msg}, args...)...) } // InEpsilonf asserts that expected and actual have a relative error less than epsilon // // Returns whether the assertion was successful (true) or not (false). func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { return InEpsilon(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...) } // InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices. func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { return InEpsilonSlice(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...) } // IsTypef asserts that the specified objects are of the same type. func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...) } // JSONEqf asserts that two JSON strings are equivalent. // // assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return JSONEq(t, expected, actual, msgAndArgs...)
}

// Lenf asserts that the specified object has specific length.
// Lenf also fails if the object has a type that len() does not accept.
//
//	assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return Len(t, object, length, msgAndArgs...)
}

// Nilf asserts that the specified object is nil.
//
//	assert.Nilf(t, err, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return Nil(t, object, msgAndArgs...)
}

// NoErrorf asserts that a function returned no error (i.e. `nil`).
//
//	actualObj, err := SomeFunction()
//	if assert.NoErrorf(t, err, "error message %s", "formatted") {
//		assert.Equal(t, expectedObj, actualObj)
//	}
//
// Returns whether the assertion was successful (true) or not (false).
func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return NoError(t, err, msgAndArgs...)
}

// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
//	assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted")
//	assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
//	assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return NotContains(t, s, contains, msgAndArgs...)
}

// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
//	if assert.NotEmptyf(t, obj, "error message %s", "formatted") {
//		assert.Equal(t, "two", obj[1])
//	}
//
// Returns whether the assertion was successful (true) or not (false).
func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return NotEmpty(t, object, msgAndArgs...)
}

// NotEqualf asserts that the specified values are NOT equal.
//
//	assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return NotEqual(t, expected, actual, msgAndArgs...)
}

// NotNilf asserts that the specified object is not nil.
//
//	assert.NotNilf(t, err, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return NotNil(t, object, msgAndArgs...)
}

// NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic.
//
//	assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool {
	msgAndArgs := append([]interface{}{msg}, args...)
	return NotPanics(t, f, msgAndArgs...)
}

// NotRegexpf asserts that a specified regexp does not match a string.
// // assert.NotRegexpf(t, regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting") // assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool { return NotRegexp(t, rx, str, append([]interface{}{msg}, args...)...) } // NotSubsetf asserts that the specified list(array, slice...) contains not all // elements given in the specified subset(array, slice...). // // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { return NotSubset(t, list, subset, append([]interface{}{msg}, args...)...) } // NotZerof asserts that i is not the zero value for its type and returns the truth. func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool { return NotZero(t, i, append([]interface{}{msg}, args...)...) } // Panicsf asserts that the code inside the specified PanicTestFunc panics. // // assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool { return Panics(t, f, append([]interface{}{msg}, args...)...) } // PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // // assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool {
	return PanicsWithValue(t, expected, f, append([]interface{}{msg}, args...)...)
}

// Regexpf asserts that a specified regexp matches a string.
//
//	assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted")
//	assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
	return Regexp(t, rx, str, append([]interface{}{msg}, args...)...)
}

// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
//
//	assert.Subsetf(t, []int{1, 2, 3}, []int{1, 2}, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
	return Subset(t, list, subset, append([]interface{}{msg}, args...)...)
}

// Truef asserts that the specified value is true.
//
//	assert.Truef(t, myBool, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
	return True(t, value, append([]interface{}{msg}, args...)...)
}

// WithinDurationf asserts that the two times are within duration delta of each other.
//
//	assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
	return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
}

// Zerof asserts that i is the zero value for its type and returns the truth.
func Zerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
	return Zero(t, i, append([]interface{}{msg}, args...)...)
}

================================================
FILE: vendor/github.com/stretchr/testify/assert/assertion_format.go.tmpl
================================================
{{.CommentFormat}}
func {{.DocInfo.Name}}f(t TestingT, {{.ParamsFormat}}) bool {
	return {{.DocInfo.Name}}(t, {{.ForwardedParamsFormat}})
}

================================================
FILE: vendor/github.com/stretchr/testify/assert/assertion_forward.go
================================================
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
 */

// NOTE(review): each method in this file forwards to the package-level
// function of the same name, supplying the wrapped a.t as the TestingT.

package assert

import (
	http "net/http"
	url "net/url"
	time "time"
)

// Condition uses a Comparison to assert a complex condition.
func (a *Assertions) Condition(comp Comparison, msgAndArgs ...interface{}) bool {
	return Condition(a.t, comp, msgAndArgs...)
}

// Conditionf uses a Comparison to assert a complex condition.
func (a *Assertions) Conditionf(comp Comparison, msg string, args ...interface{}) bool {
	return Conditionf(a.t, comp, msg, args...)
}

// Contains asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
//	a.Contains("Hello World", "World")
//	a.Contains(["Hello", "World"], "World")
//	a.Contains({"Hello": "World"}, "Hello")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool {
	return Contains(a.t, s, contains, msgAndArgs...)
}

// Containsf asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
// // a.Containsf("Hello World", "World", "error message %s", "formatted") // a.Containsf(["Hello", "World"], "World", "error message %s", "formatted") // a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool { return Containsf(a.t, s, contains, msg, args...) } // Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // // a.Empty(obj) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { return Empty(a.t, object, msgAndArgs...) } // Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // // a.Emptyf(obj, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool { return Emptyf(a.t, object, msg, args...) } // Equal asserts that two objects are equal. // // a.Equal(123, 123) // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality // cannot be determined and will always fail. func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { return Equal(a.t, expected, actual, msgAndArgs...) } // EqualError asserts that a function returned an error (i.e. not `nil`) // and that it is equal to the provided error. // // actualObj, err := SomeFunction() // a.EqualError(err, expectedErrorString) // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) bool {
	return EqualError(a.t, theError, errString, msgAndArgs...)
}

// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
//
//    actualObj, err := SomeFunction()
//    a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) EqualErrorf(theError error, errString string, msg string, args ...interface{}) bool {
	return EqualErrorf(a.t, theError, errString, msg, args...)
}

// EqualValues asserts that two objects are equal or convertible to the same types
// and equal.
//
//    a.EqualValues(uint32(123), int32(123))
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
	return EqualValues(a.t, expected, actual, msgAndArgs...)
}

// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
//
//    a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	return EqualValuesf(a.t, expected, actual, msg, args...)
}

// Equalf asserts that two objects are equal.
//
//    a.Equalf(123, 123, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail.
func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	return Equalf(a.t, expected, actual, msg, args...)
}

// Error asserts that a function returned an error (i.e. not `nil`).
//
//    actualObj, err := SomeFunction()
//    if a.Error(err) {
//        assert.Equal(t, expectedError, err)
//    }
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool {
	return Error(a.t, err, msgAndArgs...)
}

// Errorf asserts that a function returned an error (i.e. not `nil`).
//
//    actualObj, err := SomeFunction()
//    if a.Errorf(err, "error message %s", "formatted") {
//        assert.Equal(t, expectedError, err)
//    }
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool {
	return Errorf(a.t, err, msg, args...)
}

// Exactly asserts that two objects are equal in value and type.
//
//    a.Exactly(int32(123), int64(123))
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
	return Exactly(a.t, expected, actual, msgAndArgs...)
}

// Exactlyf asserts that two objects are equal in value and type.
//
//    a.Exactlyf(int32(123), int64(123), "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	return Exactlyf(a.t, expected, actual, msg, args...)
}

// Fail reports a failure message through the wrapped test object a.t.
func (a *Assertions) Fail(failureMessage string, msgAndArgs ...interface{}) bool {
	return Fail(a.t, failureMessage, msgAndArgs...)
}

// FailNow reports a failure message through a.t and then stops the test via
// the package-level FailNow.
func (a *Assertions) FailNow(failureMessage string, msgAndArgs ...interface{}) bool {
	return FailNow(a.t, failureMessage, msgAndArgs...)
}

// FailNowf reports a failure message through a.t and then stops the test via
// the package-level FailNowf; msg and args form the optional formatted message.
func (a *Assertions) FailNowf(failureMessage string, msg string, args ...interface{}) bool {
	return FailNowf(a.t, failureMessage, msg, args...)
}

// Failf reports a failure message through the wrapped test object a.t;
// msg and args form the optional formatted message.
func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{}) bool {
	return Failf(a.t, failureMessage, msg, args...)
}

// False asserts that the specified value is false.
//
//    a.False(myBool)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool {
	return False(a.t, value, msgAndArgs...)
}

// Falsef asserts that the specified value is false.
//
//    a.Falsef(myBool, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) bool {
	return Falsef(a.t, value, msg, args...)
}

// HTTPBodyContains asserts that a specified handler returns a
// body that contains a string.
//
//    a.HTTPBodyContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool {
	return HTTPBodyContains(a.t, handler, method, url, values, str)
}

// HTTPBodyContainsf asserts that a specified handler returns a
// body that contains a string.
//
//    a.HTTPBodyContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// NOTE(review): despite the "f" suffix this method accepts no message/format
// arguments; its parameters are identical to HTTPBodyContains.
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool {
	return HTTPBodyContainsf(a.t, handler, method, url, values, str)
}

// HTTPBodyNotContains asserts that a specified handler returns a
// body that does not contain a string.
//
//    a.HTTPBodyNotContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool {
	return HTTPBodyNotContains(a.t, handler, method, url, values, str)
}

// HTTPBodyNotContainsf asserts that a specified handler returns a
// body that does not contain a string.
//
//    a.HTTPBodyNotContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// NOTE(review): despite the "f" suffix this method accepts no message/format
// arguments; its parameters are identical to HTTPBodyNotContains.
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool {
	return HTTPBodyNotContainsf(a.t, handler, method, url, values, str)
}

// HTTPError asserts that a specified handler returns an error status code.
//
//    a.HTTPError(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}})
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values) bool {
	return HTTPError(a.t, handler, method, url, values)
}

// HTTPErrorf asserts that a specified handler returns an error status code.
//
//    a.HTTPErrorf(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}})
//
// NOTE(review): despite the "f" suffix this method accepts no message/format
// arguments; its parameters are identical to HTTPError.
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values) bool {
	return HTTPErrorf(a.t, handler, method, url, values)
}

// HTTPRedirect asserts that a specified handler returns a redirect status code.
//
//    a.HTTPRedirect(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}})
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values) bool {
	return HTTPRedirect(a.t, handler, method, url, values)
}

// HTTPRedirectf asserts that a specified handler returns a redirect status code.
//
//    a.HTTPRedirectf(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}})
//
// NOTE(review): despite the "f" suffix this method accepts no message/format
// arguments; its parameters are identical to HTTPRedirect.
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values) bool {
	return HTTPRedirectf(a.t, handler, method, url, values)
}

// HTTPSuccess asserts that a specified handler returns a success status code.
//
//    a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values) bool {
	return HTTPSuccess(a.t, handler, method, url, values)
}

// HTTPSuccessf asserts that a specified handler returns a success status code.
//
//    a.HTTPSuccessf(myHandler, "POST", "http://www.google.com", nil)
//
// NOTE(review): despite the "f" suffix this method accepts no message/format
// arguments; its parameters are identical to HTTPSuccess.
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url string, values url.Values) bool {
	return HTTPSuccessf(a.t, handler, method, url, values)
}

// Implements asserts that an object implements the specified interface.
//
//    a.Implements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
	return Implements(a.t, interfaceObject, object, msgAndArgs...)
}

// Implementsf asserts that an object implements the specified interface.
//
//    a.Implementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
	return Implementsf(a.t, interfaceObject, object, msg, args...)
}

// InDelta asserts that the two numerals are within delta of each other.
//
//    a.InDelta(math.Pi, (22 / 7.0), 0.01)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
	return InDelta(a.t, expected, actual, delta, msgAndArgs...)
}

// InDeltaSlice is the same as InDelta, except it compares two slices.
func (a *Assertions) InDeltaSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
	return InDeltaSlice(a.t, expected, actual, delta, msgAndArgs...)
}

// InDeltaSlicef is the same as InDelta, except it compares two slices.
func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
	return InDeltaSlicef(a.t, expected, actual, delta, msg, args...)
}

// InDeltaf asserts that the two numerals are within delta of each other.
//
//    a.InDeltaf(math.Pi, (22 / 7.0), 0.01, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
	return InDeltaf(a.t, expected, actual, delta, msg, args...)
}

// InEpsilon asserts that expected and actual have a relative error less than epsilon.
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) InEpsilon(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool {
	return InEpsilon(a.t, expected, actual, epsilon, msgAndArgs...)
}

// InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices.
func (a *Assertions) InEpsilonSlice(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool {
	return InEpsilonSlice(a.t, expected, actual, epsilon, msgAndArgs...)
}

// InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices.
func (a *Assertions) InEpsilonSlicef(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
	return InEpsilonSlicef(a.t, expected, actual, epsilon, msg, args...)
}

// InEpsilonf asserts that expected and actual have a relative error less than epsilon.
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
	return InEpsilonf(a.t, expected, actual, epsilon, msg, args...)
}

// IsType asserts that the specified objects are of the same type.
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
	return IsType(a.t, expectedType, object, msgAndArgs...)
}

// IsTypef asserts that the specified objects are of the same type.
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
	return IsTypef(a.t, expectedType, object, msg, args...)
}

// JSONEq asserts that two JSON strings are equivalent.
//
//    a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) bool {
	return JSONEq(a.t, expected, actual, msgAndArgs...)
}

// JSONEqf asserts that two JSON strings are equivalent.
//
//    a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ...interface{}) bool {
	return JSONEqf(a.t, expected, actual, msg, args...)
}

// Len asserts that the specified object has specific length.
// Len also fails if the object has a type that len() does not accept.
//
//    a.Len(mySlice, 3)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) bool {
	return Len(a.t, object, length, msgAndArgs...)
}

// Lenf asserts that the specified object has specific length.
// Lenf also fails if the object has a type that len() does not accept.
//
//    a.Lenf(mySlice, 3, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...interface{}) bool {
	return Lenf(a.t, object, length, msg, args...)
}

// Nil asserts that the specified object is nil.
//
//    a.Nil(err)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool {
	return Nil(a.t, object, msgAndArgs...)
}

// Nilf asserts that the specified object is nil.
//
//    a.Nilf(err, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) bool {
	return Nilf(a.t, object, msg, args...)
}

// NoError asserts that a function returned no error (i.e. `nil`).
//
//    actualObj, err := SomeFunction()
//    if a.NoError(err) {
//        assert.Equal(t, expectedObj, actualObj)
//    }
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool {
	return NoError(a.t, err, msgAndArgs...)
}

// NoErrorf asserts that a function returned no error (i.e. `nil`).
//
//    actualObj, err := SomeFunction()
//    if a.NoErrorf(err, "error message %s", "formatted") {
//        assert.Equal(t, expectedObj, actualObj)
//    }
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) bool {
	return NoErrorf(a.t, err, msg, args...)
}

// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
//    a.NotContains("Hello World", "Earth")
//    a.NotContains(["Hello", "World"], "Earth")
//    a.NotContains({"Hello": "World"}, "Earth")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool {
	return NotContains(a.t, s, contains, msgAndArgs...)
}

// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
//    a.NotContainsf("Hello World", "Earth", "error message %s", "formatted")
//    a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted")
//    a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool {
	return NotContainsf(a.t, s, contains, msg, args...)
}

// NotEmpty asserts that the specified object is NOT empty, i.e. not nil, "", false, 0,
// or a slice or channel with len == 0.
//
//    if a.NotEmpty(obj) {
//        assert.Equal(t, "two", obj[1])
//    }
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) bool {
	return NotEmpty(a.t, object, msgAndArgs...)
}

// NotEmptyf asserts that the specified object is NOT empty, i.e. not nil, "", false, 0,
// or a slice or channel with len == 0.
//
//    if a.NotEmptyf(obj, "error message %s", "formatted") {
//        assert.Equal(t, "two", obj[1])
//    }
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface{}) bool {
	return NotEmptyf(a.t, object, msg, args...)
}

// NotEqual asserts that the specified values are NOT equal.
//
//    a.NotEqual(obj1, obj2)
//
// Returns whether the assertion was successful (true) or not (false).
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
	return NotEqual(a.t, expected, actual, msgAndArgs...)
}

// NotEqualf asserts that the specified values are NOT equal.
//
//    a.NotEqualf(obj1, obj2, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	return NotEqualf(a.t, expected, actual, msg, args...)
}

// NotNil asserts that the specified object is not nil.
//
//    a.NotNil(err)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool {
	return NotNil(a.t, object, msgAndArgs...)
}

// NotNilf asserts that the specified object is not nil.
//
//    a.NotNilf(err, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}) bool {
	return NotNilf(a.t, object, msg, args...)
}

// NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic.
//
//    a.NotPanics(func(){ RemainCalm() })
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool {
	return NotPanics(a.t, f, msgAndArgs...)
}

// NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic.
//
//    a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotPanicsf(f PanicTestFunc, msg string, args ...interface{}) bool {
	return NotPanicsf(a.t, f, msg, args...)
}

// NotRegexp asserts that a specified regexp does not match a string.
//
//    a.NotRegexp(regexp.MustCompile("starts"), "it's starting")
//    a.NotRegexp("^start", "it's not starting")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
	return NotRegexp(a.t, rx, str, msgAndArgs...)
}

// NotRegexpf asserts that a specified regexp does not match a string.
//
//    a.NotRegexpf(regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted")
//    a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool {
	return NotRegexpf(a.t, rx, str, msg, args...)
}

// NotSubset asserts that the specified list(array, slice...) does not contain all
// elements given in the specified subset(array, slice...).
//
//    a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
	return NotSubset(a.t, list, subset, msgAndArgs...)
}

// NotSubsetf asserts that the specified list(array, slice...) does not contain all
// elements given in the specified subset(array, slice...).
//
//    a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
	return NotSubsetf(a.t, list, subset, msg, args...)
}

// NotZero asserts that i is not the zero value for its type and returns the truth.
func (a *Assertions) NotZero(i interface{}, msgAndArgs ...interface{}) bool {
	return NotZero(a.t, i, msgAndArgs...)
}

// NotZerof asserts that i is not the zero value for its type and returns the truth.
func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) bool {
	return NotZerof(a.t, i, msg, args...)
}

// Panics asserts that the code inside the specified PanicTestFunc panics.
//
//    a.Panics(func(){ GoCrazy() })
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool {
	return Panics(a.t, f, msgAndArgs...)
}

// PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
//    a.PanicsWithValue("crazy error", func(){ GoCrazy() })
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) PanicsWithValue(expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool {
	return PanicsWithValue(a.t, expected, f, msgAndArgs...)
}

// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
//    a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) PanicsWithValuef(expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool {
	return PanicsWithValuef(a.t, expected, f, msg, args...)
}

// Panicsf asserts that the code inside the specified PanicTestFunc panics.
//
//    a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) bool {
	return Panicsf(a.t, f, msg, args...)
}

// Regexp asserts that a specified regexp matches a string.
//
//    a.Regexp(regexp.MustCompile("start"), "it's starting")
//    a.Regexp("start...$", "it's not starting")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
	return Regexp(a.t, rx, str, msgAndArgs...)
}

// Regexpf asserts that a specified regexp matches a string.
//
//    a.Regexpf(regexp.MustCompile("start"), "it's starting", "error message %s", "formatted")
//    a.Regexpf("start...$", "it's not starting", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool {
	return Regexpf(a.t, rx, str, msg, args...)
}

// Subset asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
//
//    a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
	return Subset(a.t, list, subset, msgAndArgs...)
}

// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
//
//    a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
	return Subsetf(a.t, list, subset, msg, args...)
}

// True asserts that the specified value is true.
//
//    a.True(myBool)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool {
	return True(a.t, value, msgAndArgs...)
}

// Truef asserts that the specified value is true.
//
//    a.Truef(myBool, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) Truef(value bool, msg string, args ...interface{}) bool {
	return Truef(a.t, value, msg, args...)
}

// WithinDuration asserts that the two times are within duration delta of each other.
//
//    a.WithinDuration(time.Now(), time.Now(), 10*time.Second)
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool {
	return WithinDuration(a.t, expected, actual, delta, msgAndArgs...)
}

// WithinDurationf asserts that the two times are within duration delta of each other.
//
//    a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
	return WithinDurationf(a.t, expected, actual, delta, msg, args...)
}

// Zero asserts that i is the zero value for its type and returns the truth.
func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) bool {
	return Zero(a.t, i, msgAndArgs...)
}

// Zerof asserts that i is the zero value for its type and returns the truth.
func (a *Assertions) Zerof(i interface{}, msg string, args ...interface{}) bool {
	return Zerof(a.t, i, msg, args...)
}

================================================
FILE: vendor/github.com/stretchr/testify/assert/assertion_forward.go.tmpl
================================================
{{.CommentWithoutT "a"}}
func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) bool {
	return {{.DocInfo.Name}}(a.t, {{.ForwardedParams}})
}

================================================
FILE: vendor/github.com/stretchr/testify/assert/assertions.go
================================================
package assert

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"reflect"
	"regexp"
	"runtime"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"github.com/davecgh/go-spew/spew"
	"github.com/pmezard/go-difflib/difflib"
)

//go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_format.go.tmpl

// TestingT is an interface wrapper around *testing.T. Only Errorf is
// required, so any type with a matching Errorf method can receive
// assertion failures.
type TestingT interface {
	Errorf(format string, args ...interface{})
}

// Comparison is a custom function that returns true on success and false on failure.
type Comparison func() (success bool)

/*
	Helper functions
*/

// ObjectsAreEqual determines if two objects are considered equal.
//
// This function does no assertion of any kind.
func ObjectsAreEqual(expected, actual interface{}) bool { if expected == nil || actual == nil { return expected == actual } if exp, ok := expected.([]byte); ok { act, ok := actual.([]byte) if !ok { return false } else if exp == nil || act == nil { return exp == nil && act == nil } return bytes.Equal(exp, act) } return reflect.DeepEqual(expected, actual) } // ObjectsAreEqualValues gets whether two objects are equal, or if their // values are equal. func ObjectsAreEqualValues(expected, actual interface{}) bool { if ObjectsAreEqual(expected, actual) { return true } actualType := reflect.TypeOf(actual) if actualType == nil { return false } expectedValue := reflect.ValueOf(expected) if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { // Attempt comparison after type conversion return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) } return false } /* CallerInfo is necessary because the assert functions use the testing object internally, causing it to print the file:line of the assert method, rather than where the problem actually occurred in calling code.*/ // CallerInfo returns an array of strings containing the file and line number // of each stack frame leading from the current test to the assert call that // failed. func CallerInfo() []string { pc := uintptr(0) file := "" line := 0 ok := false name := "" callers := []string{} for i := 0; ; i++ { pc, file, line, ok = runtime.Caller(i) if !ok { // The breaks below failed to terminate the loop, and we ran off the // end of the call stack. break } // This is a huge edge case, but it will panic if this is the case, see #180 if file == "" { break } f := runtime.FuncForPC(pc) if f == nil { break } name = f.Name() // testing.tRunner is the standard library function that calls // tests. 
Subtests are called directly by tRunner, without going through // the Test/Benchmark/Example function that contains the t.Run calls, so // with subtests we should break when we hit tRunner, without adding it // to the list of callers. if name == "testing.tRunner" { break } parts := strings.Split(file, "/") file = parts[len(parts)-1] if len(parts) > 1 { dir := parts[len(parts)-2] if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" { callers = append(callers, fmt.Sprintf("%s:%d", file, line)) } } // Drop the package segments := strings.Split(name, ".") name = segments[len(segments)-1] if isTest(name, "Test") || isTest(name, "Benchmark") || isTest(name, "Example") { break } } return callers } // Stolen from the `go test` tool. // isTest tells whether name looks like a test (or benchmark, according to prefix). // It is a Test (say) if there is a character after Test that is not a lower-case letter. // We don't want TesticularCancer. func isTest(name, prefix string) bool { if !strings.HasPrefix(name, prefix) { return false } if len(name) == len(prefix) { // "Test" is ok return true } rune, _ := utf8.DecodeRuneInString(name[len(prefix):]) return !unicode.IsLower(rune) } // getWhitespaceString returns a string that is long enough to overwrite the default // output from the go testing framework. func getWhitespaceString() string { _, file, line, ok := runtime.Caller(1) if !ok { return "" } parts := strings.Split(file, "/") file = parts[len(parts)-1] return strings.Repeat(" ", len(fmt.Sprintf("%s:%d: ", file, line))) } func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { if len(msgAndArgs) == 0 || msgAndArgs == nil { return "" } if len(msgAndArgs) == 1 { return msgAndArgs[0].(string) } if len(msgAndArgs) > 1 { return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) } return "" } // Aligns the provided message so that all lines after the first line start at the same location as the first line. 
// Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab). // The longestLabelLen parameter specifies the length of the longest label in the output (required becaues this is the // basis on which the alignment occurs). func indentMessageLines(message string, longestLabelLen int) string { outBuf := new(bytes.Buffer) for i, scanner := 0, bufio.NewScanner(strings.NewReader(message)); scanner.Scan(); i++ { // no need to align first line because it starts at the correct location (after the label) if i != 0 { // append alignLen+1 spaces to align with "{{longestLabel}}:" before adding tab outBuf.WriteString("\n\r\t" + strings.Repeat(" ", longestLabelLen+1) + "\t") } outBuf.WriteString(scanner.Text()) } return outBuf.String() } type failNower interface { FailNow() } // FailNow fails test func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { Fail(t, failureMessage, msgAndArgs...) // We cannot extend TestingT with FailNow() and // maintain backwards compatibility, so we fallback // to panicking when FailNow is not available in // TestingT. // See issue #263 if t, ok := t.(failNower); ok { t.FailNow() } else { panic("test failed and t is missing `FailNow()`") } return false } // Fail reports a failure through func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { content := []labeledContent{ {"Error Trace", strings.Join(CallerInfo(), "\n\r\t\t\t")}, {"Error", failureMessage}, } message := messageFromMsgAndArgs(msgAndArgs...) if len(message) > 0 { content = append(content, labeledContent{"Messages", message}) } t.Errorf("%s", "\r"+getWhitespaceString()+labeledOutput(content...)) return false } type labeledContent struct { label string content string } // labeledOutput returns a string consisting of the provided labeledContent. 
Each labeled output is appended in the following manner: // // \r\t{{label}}:{{align_spaces}}\t{{content}}\n // // The initial carriage return is required to undo/erase any padding added by testing.T.Errorf. The "\t{{label}}:" is for the label. // If a label is shorter than the longest label provided, padding spaces are added to make all the labels match in length. Once this // alignment is achieved, "\t{{content}}\n" is added for the output. // // If the content of the labeledOutput contains line breaks, the subsequent lines are aligned so that they start at the same location as the first line. func labeledOutput(content ...labeledContent) string { longestLabel := 0 for _, v := range content { if len(v.label) > longestLabel { longestLabel = len(v.label) } } var output string for _, v := range content { output += "\r\t" + v.label + ":" + strings.Repeat(" ", longestLabel-len(v.label)) + "\t" + indentMessageLines(v.content, longestLabel) + "\n" } return output } // Implements asserts that an object is implemented by the specified interface. // // assert.Implements(t, (*MyInterface)(nil), new(MyObject)) func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { interfaceType := reflect.TypeOf(interfaceObject).Elem() if !reflect.TypeOf(object).Implements(interfaceType) { return Fail(t, fmt.Sprintf("%T must implement %v", object, interfaceType), msgAndArgs...) } return true } // IsType asserts that the specified objects are of the same type. func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) } return true } // Equal asserts that two objects are equal. 
// // assert.Equal(t, 123, 123) // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality // cannot be determined and will always fail. func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { if err := validateEqualArgs(expected, actual); err != nil { return Fail(t, fmt.Sprintf("Invalid operation: %#v == %#v (%s)", expected, actual, err), msgAndArgs...) } if !ObjectsAreEqual(expected, actual) { diff := diff(expected, actual) expected, actual = formatUnequalValues(expected, actual) return Fail(t, fmt.Sprintf("Not equal: \n"+ "expected: %s\n"+ "actual: %s%s", expected, actual, diff), msgAndArgs...) } return true } // formatUnequalValues takes two values of arbitrary types and returns string // representations appropriate to be presented to the user. // // If the values are not of like type, the returned strings will be prefixed // with the type name, and the value will be enclosed in parenthesis similar // to a type conversion in the Go grammar. func formatUnequalValues(expected, actual interface{}) (e string, a string) { if reflect.TypeOf(expected) != reflect.TypeOf(actual) { return fmt.Sprintf("%T(%#v)", expected, expected), fmt.Sprintf("%T(%#v)", actual, actual) } return fmt.Sprintf("%#v", expected), fmt.Sprintf("%#v", actual) } // EqualValues asserts that two objects are equal or convertable to the same types // and equal. // // assert.EqualValues(t, uint32(123), int32(123)) // // Returns whether the assertion was successful (true) or not (false). 
func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { if !ObjectsAreEqualValues(expected, actual) { diff := diff(expected, actual) expected, actual = formatUnequalValues(expected, actual) return Fail(t, fmt.Sprintf("Not equal: \n"+ "expected: %s\n"+ "actual: %s%s", expected, actual, diff), msgAndArgs...) } return true } // Exactly asserts that two objects are equal is value and type. // // assert.Exactly(t, int32(123), int64(123)) // // Returns whether the assertion was successful (true) or not (false). func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { aType := reflect.TypeOf(expected) bType := reflect.TypeOf(actual) if aType != bType { return Fail(t, fmt.Sprintf("Types expected to match exactly\n\r\t%v != %v", aType, bType), msgAndArgs...) } return Equal(t, expected, actual, msgAndArgs...) } // NotNil asserts that the specified object is not nil. // // assert.NotNil(t, err) // // Returns whether the assertion was successful (true) or not (false). func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { if !isNil(object) { return true } return Fail(t, "Expected value not to be nil.", msgAndArgs...) } // isNil checks if a specified object is nil or not, without Failing. func isNil(object interface{}) bool { if object == nil { return true } value := reflect.ValueOf(object) kind := value.Kind() if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() { return true } return false } // Nil asserts that the specified object is nil. // // assert.Nil(t, err) // // Returns whether the assertion was successful (true) or not (false). func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { if isNil(object) { return true } return Fail(t, fmt.Sprintf("Expected nil, but got: %#v", object), msgAndArgs...) 
} var numericZeros = []interface{}{ int(0), int8(0), int16(0), int32(0), int64(0), uint(0), uint8(0), uint16(0), uint32(0), uint64(0), float32(0), float64(0), } // isEmpty gets whether the specified object is considered empty or not. func isEmpty(object interface{}) bool { if object == nil { return true } else if object == "" { return true } else if object == false { return true } for _, v := range numericZeros { if object == v { return true } } objValue := reflect.ValueOf(object) switch objValue.Kind() { case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: { return (objValue.Len() == 0) } case reflect.Struct: switch object.(type) { case time.Time: return object.(time.Time).IsZero() } case reflect.Ptr: { if objValue.IsNil() { return true } switch object.(type) { case *time.Time: return object.(*time.Time).IsZero() default: return false } } } return false } // Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // // assert.Empty(t, obj) // // Returns whether the assertion was successful (true) or not (false). func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { pass := isEmpty(object) if !pass { Fail(t, fmt.Sprintf("Should be empty, but was %v", object), msgAndArgs...) } return pass } // NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // // if assert.NotEmpty(t, obj) { // assert.Equal(t, "two", obj[1]) // } // // Returns whether the assertion was successful (true) or not (false). func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { pass := !isEmpty(object) if !pass { Fail(t, fmt.Sprintf("Should NOT be empty, but was %v", object), msgAndArgs...) } return pass } // getLen try to get length of object. // return (false, 0) if impossible. 
func getLen(x interface{}) (ok bool, length int) { v := reflect.ValueOf(x) defer func() { if e := recover(); e != nil { ok = false } }() return true, v.Len() } // Len asserts that the specified object has specific length. // Len also fails if the object has a type that len() not accept. // // assert.Len(t, mySlice, 3) // // Returns whether the assertion was successful (true) or not (false). func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool { ok, l := getLen(object) if !ok { return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...) } if l != length { return Fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", object, length, l), msgAndArgs...) } return true } // True asserts that the specified value is true. // // assert.True(t, myBool) // // Returns whether the assertion was successful (true) or not (false). func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { if value != true { return Fail(t, "Should be true", msgAndArgs...) } return true } // False asserts that the specified value is false. // // assert.False(t, myBool) // // Returns whether the assertion was successful (true) or not (false). func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { if value != false { return Fail(t, "Should be false", msgAndArgs...) } return true } // NotEqual asserts that the specified values are NOT equal. // // assert.NotEqual(t, obj1, obj2) // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { if err := validateEqualArgs(expected, actual); err != nil { return Fail(t, fmt.Sprintf("Invalid operation: %#v != %#v (%s)", expected, actual, err), msgAndArgs...) 
} if ObjectsAreEqual(expected, actual) { return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) } return true } // containsElement try loop over the list check if the list includes the element. // return (false, false) if impossible. // return (true, false) if element was not found. // return (true, true) if element was found. func includeElement(list interface{}, element interface{}) (ok, found bool) { listValue := reflect.ValueOf(list) elementValue := reflect.ValueOf(element) defer func() { if e := recover(); e != nil { ok = false found = false } }() if reflect.TypeOf(list).Kind() == reflect.String { return true, strings.Contains(listValue.String(), elementValue.String()) } if reflect.TypeOf(list).Kind() == reflect.Map { mapKeys := listValue.MapKeys() for i := 0; i < len(mapKeys); i++ { if ObjectsAreEqual(mapKeys[i].Interface(), element) { return true, true } } return true, false } for i := 0; i < listValue.Len(); i++ { if ObjectsAreEqual(listValue.Index(i).Interface(), element) { return true, true } } return true, false } // Contains asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // // assert.Contains(t, "Hello World", "World") // assert.Contains(t, ["Hello", "World"], "World") // assert.Contains(t, {"Hello": "World"}, "Hello") // // Returns whether the assertion was successful (true) or not (false). func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { ok, found := includeElement(s, contains) if !ok { return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...) } if !found { return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", s, contains), msgAndArgs...) } return true } // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. 
// // assert.NotContains(t, "Hello World", "Earth") // assert.NotContains(t, ["Hello", "World"], "Earth") // assert.NotContains(t, {"Hello": "World"}, "Earth") // // Returns whether the assertion was successful (true) or not (false). func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { ok, found := includeElement(s, contains) if !ok { return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...) } if found { return Fail(t, fmt.Sprintf("\"%s\" should not contain \"%s\"", s, contains), msgAndArgs...) } return true } // Subset asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // // assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // // Returns whether the assertion was successful (true) or not (false). func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { if subset == nil { return true // we consider nil to be equal to the nil set } subsetValue := reflect.ValueOf(subset) defer func() { if e := recover(); e != nil { ok = false } }() listKind := reflect.TypeOf(list).Kind() subsetKind := reflect.TypeOf(subset).Kind() if listKind != reflect.Array && listKind != reflect.Slice { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) } if subsetKind != reflect.Array && subsetKind != reflect.Slice { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) } for i := 0; i < subsetValue.Len(); i++ { element := subsetValue.Index(i).Interface() ok, found := includeElement(list, element) if !ok { return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) } if !found { return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, element), msgAndArgs...) } } return true } // NotSubset asserts that the specified list(array, slice...) 
contains not all // elements given in the specified subset(array, slice...). // // assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // // Returns whether the assertion was successful (true) or not (false). func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { if subset == nil { return false // we consider nil to be equal to the nil set } subsetValue := reflect.ValueOf(subset) defer func() { if e := recover(); e != nil { ok = false } }() listKind := reflect.TypeOf(list).Kind() subsetKind := reflect.TypeOf(subset).Kind() if listKind != reflect.Array && listKind != reflect.Slice { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...) } if subsetKind != reflect.Array && subsetKind != reflect.Slice { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) } for i := 0; i < subsetValue.Len(); i++ { element := subsetValue.Index(i).Interface() ok, found := includeElement(list, element) if !ok { return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) } if !found { return true } } return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...) } // Condition uses a Comparison to assert a complex condition. func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { result := comp() if !result { Fail(t, "Condition failed!", msgAndArgs...) } return result } // PanicTestFunc defines a func that should be passed to the assert.Panics and assert.NotPanics // methods, and represents a simple func that takes no arguments, and returns nothing. type PanicTestFunc func() // didPanic returns true if the function passed to it panics. Otherwise, it returns false. 
func didPanic(f PanicTestFunc) (bool, interface{}) { didPanic := false var message interface{} func() { defer func() { if message = recover(); message != nil { didPanic = true } }() // call the target function f() }() return didPanic, message } // Panics asserts that the code inside the specified PanicTestFunc panics. // // assert.Panics(t, func(){ GoCrazy() }) // // Returns whether the assertion was successful (true) or not (false). func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { if funcDidPanic, panicValue := didPanic(f); !funcDidPanic { return Fail(t, fmt.Sprintf("func %#v should panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) } return true } // PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // // assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) // // Returns whether the assertion was successful (true) or not (false). func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool { funcDidPanic, panicValue := didPanic(f) if !funcDidPanic { return Fail(t, fmt.Sprintf("func %#v should panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) } if panicValue != expected { return Fail(t, fmt.Sprintf("func %#v should panic with value:\t%v\n\r\tPanic value:\t%v", f, expected, panicValue), msgAndArgs...) } return true } // NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. // // assert.NotPanics(t, func(){ RemainCalm() }) // // Returns whether the assertion was successful (true) or not (false). func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { if funcDidPanic, panicValue := didPanic(f); funcDidPanic { return Fail(t, fmt.Sprintf("func %#v should not panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) } return true } // WithinDuration asserts that the two times are within duration delta of each other. 
// // assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) // // Returns whether the assertion was successful (true) or not (false). func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { dt := expected.Sub(actual) if dt < -delta || dt > delta { return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) } return true } func toFloat(x interface{}) (float64, bool) { var xf float64 xok := true switch xn := x.(type) { case uint8: xf = float64(xn) case uint16: xf = float64(xn) case uint32: xf = float64(xn) case uint64: xf = float64(xn) case int: xf = float64(xn) case int8: xf = float64(xn) case int16: xf = float64(xn) case int32: xf = float64(xn) case int64: xf = float64(xn) case float32: xf = float64(xn) case float64: xf = float64(xn) case time.Duration: xf = float64(xn) default: xok = false } return xf, xok } // InDelta asserts that the two numerals are within delta of each other. // // assert.InDelta(t, math.Pi, (22 / 7.0), 0.01) // // Returns whether the assertion was successful (true) or not (false). func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { af, aok := toFloat(expected) bf, bok := toFloat(actual) if !aok || !bok { return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...) } if math.IsNaN(af) { return Fail(t, fmt.Sprintf("Expected must not be NaN"), msgAndArgs...) } if math.IsNaN(bf) { return Fail(t, fmt.Sprintf("Expected %v with delta %v, but was NaN", expected, delta), msgAndArgs...) } dt := af - bf if dt < -delta || dt > delta { return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) } return true } // InDeltaSlice is the same as InDelta, except it compares two slices. 
func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { if expected == nil || actual == nil || reflect.TypeOf(actual).Kind() != reflect.Slice || reflect.TypeOf(expected).Kind() != reflect.Slice { return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...) } actualSlice := reflect.ValueOf(actual) expectedSlice := reflect.ValueOf(expected) for i := 0; i < actualSlice.Len(); i++ { result := InDelta(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), delta, msgAndArgs...) if !result { return result } } return true } func calcRelativeError(expected, actual interface{}) (float64, error) { af, aok := toFloat(expected) if !aok { return 0, fmt.Errorf("expected value %q cannot be converted to float", expected) } if af == 0 { return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error") } bf, bok := toFloat(actual) if !bok { return 0, fmt.Errorf("actual value %q cannot be converted to float", actual) } return math.Abs(af-bf) / math.Abs(af), nil } // InEpsilon asserts that expected and actual have a relative error less than epsilon // // Returns whether the assertion was successful (true) or not (false). func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { actualEpsilon, err := calcRelativeError(expected, actual) if err != nil { return Fail(t, err.Error(), msgAndArgs...) } if actualEpsilon > epsilon { return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+ " < %#v (actual)", epsilon, actualEpsilon), msgAndArgs...) } return true } // InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. 
func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { if expected == nil || actual == nil || reflect.TypeOf(actual).Kind() != reflect.Slice || reflect.TypeOf(expected).Kind() != reflect.Slice { return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...) } actualSlice := reflect.ValueOf(actual) expectedSlice := reflect.ValueOf(expected) for i := 0; i < actualSlice.Len(); i++ { result := InEpsilon(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), epsilon) if !result { return result } } return true } /* Errors */ // NoError asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() // if assert.NoError(t, err) { // assert.Equal(t, expectedObj, actualObj) // } // // Returns whether the assertion was successful (true) or not (false). func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { if err != nil { return Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...) } return true } // Error asserts that a function returned an error (i.e. not `nil`). // // actualObj, err := SomeFunction() // if assert.Error(t, err) { // assert.Equal(t, expectedError, err) // } // // Returns whether the assertion was successful (true) or not (false). func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { if err == nil { return Fail(t, "An error is expected but got nil.", msgAndArgs...) } return true } // EqualError asserts that a function returned an error (i.e. not `nil`) // and that it is equal to the provided error. // // actualObj, err := SomeFunction() // assert.EqualError(t, err, expectedErrorString) // // Returns whether the assertion was successful (true) or not (false). func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool { if !Error(t, theError, msgAndArgs...) 
{ return false } expected := errString actual := theError.Error() // don't need to use deep equals here, we know they are both strings if expected != actual { return Fail(t, fmt.Sprintf("Error message not equal:\n"+ "expected: %q\n"+ "actual: %q", expected, actual), msgAndArgs...) } return true } // matchRegexp return true if a specified regexp matches a string. func matchRegexp(rx interface{}, str interface{}) bool { var r *regexp.Regexp if rr, ok := rx.(*regexp.Regexp); ok { r = rr } else { r = regexp.MustCompile(fmt.Sprint(rx)) } return (r.FindStringIndex(fmt.Sprint(str)) != nil) } // Regexp asserts that a specified regexp matches a string. // // assert.Regexp(t, regexp.MustCompile("start"), "it's starting") // assert.Regexp(t, "start...$", "it's not starting") // // Returns whether the assertion was successful (true) or not (false). func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { match := matchRegexp(rx, str) if !match { Fail(t, fmt.Sprintf("Expect \"%v\" to match \"%v\"", str, rx), msgAndArgs...) } return match } // NotRegexp asserts that a specified regexp does not match a string. // // assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") // assert.NotRegexp(t, "^start", "it's not starting") // // Returns whether the assertion was successful (true) or not (false). func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { match := matchRegexp(rx, str) if match { Fail(t, fmt.Sprintf("Expect \"%v\" to NOT match \"%v\"", str, rx), msgAndArgs...) } return !match } // Zero asserts that i is the zero value for its type and returns the truth. func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...) } return true } // NotZero asserts that i is not the zero value for its type and returns the truth. 
func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...) } return true } // JSONEq asserts that two JSON strings are equivalent. // // assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) // // Returns whether the assertion was successful (true) or not (false). func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { var expectedJSONAsInterface, actualJSONAsInterface interface{} if err := json.Unmarshal([]byte(expected), &expectedJSONAsInterface); err != nil { return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()), msgAndArgs...) } if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil { return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()), msgAndArgs...) } return Equal(t, expectedJSONAsInterface, actualJSONAsInterface, msgAndArgs...) } func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { t := reflect.TypeOf(v) k := t.Kind() if k == reflect.Ptr { t = t.Elem() k = t.Kind() } return t, k } // diff returns a diff of both values as long as both are of the same type and // are a struct, map, slice or array. Otherwise it returns an empty string. 
func diff(expected interface{}, actual interface{}) string { if expected == nil || actual == nil { return "" } et, ek := typeAndKind(expected) at, _ := typeAndKind(actual) if et != at { return "" } if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { return "" } e := spewConfig.Sdump(expected) a := spewConfig.Sdump(actual) diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ A: difflib.SplitLines(e), B: difflib.SplitLines(a), FromFile: "Expected", FromDate: "", ToFile: "Actual", ToDate: "", Context: 1, }) return "\n\nDiff:\n" + diff } // validateEqualArgs checks whether provided arguments can be safely used in the // Equal/NotEqual functions. func validateEqualArgs(expected, actual interface{}) error { if isFunction(expected) || isFunction(actual) { return errors.New("cannot take func type as argument") } return nil } func isFunction(arg interface{}) bool { if arg == nil { return false } return reflect.TypeOf(arg).Kind() == reflect.Func } var spewConfig = spew.ConfigState{ Indent: " ", DisablePointerAddresses: true, DisableCapacities: true, SortKeys: true, } ================================================ FILE: vendor/github.com/stretchr/testify/assert/doc.go ================================================ // Package assert provides a set of comprehensive testing tools for use with the normal Go testing system. 
//
// Example Usage
//
// The following is a complete example using assert in a standard test function:
//
//	import (
//	  "testing"
//	  "github.com/stretchr/testify/assert"
//	)
//
//	func TestSomething(t *testing.T) {
//
//	  var a string = "Hello"
//	  var b string = "Hello"
//
//	  assert.Equal(t, a, b, "The two words should be the same.")
//
//	}
//
// If you assert many times, use the format below:
//
//	import (
//	  "testing"
//	  "github.com/stretchr/testify/assert"
//	)
//
//	func TestSomething(t *testing.T) {
//	  assert := assert.New(t)
//
//	  var a string = "Hello"
//	  var b string = "Hello"
//
//	  assert.Equal(a, b, "The two words should be the same.")
//	}
//
// Assertions
//
// Assertions allow you to easily write test code, and are global funcs in the `assert` package.
// All assertion functions take, as the first argument, the `*testing.T` object provided by the
// testing framework. This allows the assertion funcs to write the failures and other details to
// the correct place.
//
// Every assertion function also takes an optional string message (or a format string followed
// by its arguments) as the final argument, allowing custom error messages to be appended to the
// message the assertion method outputs.
package assert

================================================
FILE: vendor/github.com/stretchr/testify/assert/errors.go
================================================
package assert

import (
	"errors"
)

// AnError is an error instance useful for testing. If the code does not care
// about error specifics, and only needs to return the error for example, this
// error should be used to make the test code more readable.
var AnError = errors.New("assert.AnError general error for testing")

================================================
FILE: vendor/github.com/stretchr/testify/assert/forward_assertions.go
================================================
package assert

// Assertions provides assertion methods around the
// TestingT interface.
type Assertions struct { t TestingT } // New makes a new Assertions object for the specified TestingT. func New(t TestingT) *Assertions { return &Assertions{ t: t, } } //go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_forward.go.tmpl -include-format-funcs ================================================ FILE: vendor/github.com/stretchr/testify/assert/http_assertions.go ================================================ package assert import ( "fmt" "net/http" "net/http/httptest" "net/url" "strings" ) // httpCode is a helper that returns HTTP code of the response. It returns -1 and // an error if building a new request fails. func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) { w := httptest.NewRecorder() req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) if err != nil { return -1, err } handler(w, req) return w.Code, nil } // HTTPSuccess asserts that a specified handler returns a success status code. // // assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) // // Returns whether the assertion was successful (true) or not (false). func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values) bool { code, err := httpCode(handler, method, url, values) if err != nil { Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) return false } isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent if !isSuccessCode { Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code)) } return isSuccessCode } // HTTPRedirect asserts that a specified handler returns a redirect status code. // // assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). 
func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values) bool { code, err := httpCode(handler, method, url, values) if err != nil { Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) return false } isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect if !isRedirectCode { Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code)) } return isRedirectCode } // HTTPError asserts that a specified handler returns an error status code. // // assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values) bool { code, err := httpCode(handler, method, url, values) if err != nil { Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) return false } isErrorCode := code >= http.StatusBadRequest if !isErrorCode { Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code)) } return isErrorCode } // HTTPBody is a helper that returns HTTP body of the response. It returns // empty string if building a new request fails. func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string { w := httptest.NewRecorder() req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) if err != nil { return "" } handler(w, req) return w.Body.String() } // HTTPBodyContains asserts that a specified handler returns a // body that contains a string. // // assert.HTTPBodyContains(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). 
func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}) bool { body := HTTPBody(handler, method, url, values) contains := strings.Contains(body, fmt.Sprint(str)) if !contains { Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) } return contains } // HTTPBodyNotContains asserts that a specified handler returns a // body that does not contain a string. // // assert.HTTPBodyNotContains(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}) bool { body := HTTPBody(handler, method, url, values) contains := strings.Contains(body, fmt.Sprint(str)) if contains { Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) } return !contains } ================================================ FILE: vendor/github.com/stretchr/testify/mock/doc.go ================================================ // Package mock provides a system by which it is possible to mock your objects // and verify calls are happening as expected. // // Example Usage // // The mock package provides an object, Mock, that tracks activity on another object. It is usually // embedded into a test object as shown below: // // type MyTestObject struct { // // add a Mock object instance // mock.Mock // // // other fields go here as normal // } // // When implementing the methods of an interface, you wire your functions up // to call the Mock.Called(args...) method, and return the appropriate values. 
// // For example, to mock a method that saves the name and age of a person and returns // the year of their birth or an error, you might write this: // // func (o *MyTestObject) SavePersonDetails(firstname, lastname string, age int) (int, error) { // args := o.Called(firstname, lastname, age) // return args.Int(0), args.Error(1) // } // // The Int, Error and Bool methods are examples of strongly typed getters that take the argument // index position. Given this argument list: // // (12, true, "Something") // // You could read them out strongly typed like this: // // args.Int(0) // args.Bool(1) // args.String(2) // // For objects of your own type, use the generic Arguments.Get(index) method and make a type assertion: // // return args.Get(0).(*MyObject), args.Get(1).(*AnotherObjectOfMine) // // This may cause a panic if the object you are getting is nil (the type assertion will fail), in those // cases you should check for nil first. package mock ================================================ FILE: vendor/github.com/stretchr/testify/mock/mock.go ================================================ package mock import ( "fmt" "reflect" "regexp" "runtime" "strings" "sync" "time" "github.com/davecgh/go-spew/spew" "github.com/pmezard/go-difflib/difflib" "github.com/stretchr/objx" "github.com/stretchr/testify/assert" ) // TestingT is an interface wrapper around *testing.T type TestingT interface { Logf(format string, args ...interface{}) Errorf(format string, args ...interface{}) FailNow() } /* Call */ // Call represents a method call and is used for setting expectations, // as well as recording activity. type Call struct { Parent *Mock // The name of the method that was or will be called. Method string // Holds the arguments of the method. Arguments Arguments // Holds the arguments that should be returned when // this method is called. ReturnArguments Arguments // The number of times to return the return arguments when setting // expectations. 
0 means to always return the value. Repeatability int // Amount of times this call has been called totalCalls int // Holds a channel that will be used to block the Return until it either // receives a message or is closed. nil means it returns immediately. WaitFor <-chan time.Time // Holds a handler used to manipulate arguments content that are passed by // reference. It's useful when mocking methods such as unmarshalers or // decoders. RunFn func(Arguments) } func newCall(parent *Mock, methodName string, methodArguments ...interface{}) *Call { return &Call{ Parent: parent, Method: methodName, Arguments: methodArguments, ReturnArguments: make([]interface{}, 0), Repeatability: 0, WaitFor: nil, RunFn: nil, } } func (c *Call) lock() { c.Parent.mutex.Lock() } func (c *Call) unlock() { c.Parent.mutex.Unlock() } // Return specifies the return arguments for the expectation. // // Mock.On("DoSomething").Return(errors.New("failed")) func (c *Call) Return(returnArguments ...interface{}) *Call { c.lock() defer c.unlock() c.ReturnArguments = returnArguments return c } // Once indicates that that the mock should only return the value once. // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() func (c *Call) Once() *Call { return c.Times(1) } // Twice indicates that that the mock should only return the value twice. // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() func (c *Call) Twice() *Call { return c.Times(2) } // Times indicates that that the mock should only return the indicated number // of times. // // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) func (c *Call) Times(i int) *Call { c.lock() defer c.unlock() c.Repeatability = i return c } // WaitUntil sets the channel that will block the mock's return until its closed // or a message is received. 
// // Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second)) func (c *Call) WaitUntil(w <-chan time.Time) *Call { c.lock() defer c.unlock() c.WaitFor = w return c } // After sets how long to block until the call returns // // Mock.On("MyMethod", arg1, arg2).After(time.Second) func (c *Call) After(d time.Duration) *Call { return c.WaitUntil(time.After(d)) } // Run sets a handler to be called before returning. It can be used when // mocking a method such as unmarshalers that takes a pointer to a struct and // sets properties in such struct // // Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}").Return().Run(func(args Arguments) { // arg := args.Get(0).(*map[string]interface{}) // arg["foo"] = "bar" // }) func (c *Call) Run(fn func(args Arguments)) *Call { c.lock() defer c.unlock() c.RunFn = fn return c } // On chains a new expectation description onto the mocked interface. This // allows syntax like. // // Mock. // On("MyMethod", 1).Return(nil). // On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) func (c *Call) On(methodName string, arguments ...interface{}) *Call { return c.Parent.On(methodName, arguments...) } // Mock is the workhorse used to track activity on another object. // For an example of its usage, refer to the "Example Usage" section at the top // of this document. type Mock struct { // Represents the calls that are expected of // an object. ExpectedCalls []*Call // Holds the calls that were made to this mocked object. Calls []Call // TestData holds any data that might be useful for testing. Testify ignores // this data completely allowing you to do whatever you like with it. testData objx.Map mutex sync.Mutex } // TestData holds any data that might be useful for testing. Testify ignores // this data completely allowing you to do whatever you like with it. 
func (m *Mock) TestData() objx.Map { if m.testData == nil { m.testData = make(objx.Map) } return m.testData } /* Setting expectations */ // On starts a description of an expectation of the specified method // being called. // // Mock.On("MyMethod", arg1, arg2) func (m *Mock) On(methodName string, arguments ...interface{}) *Call { for _, arg := range arguments { if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg)) } } m.mutex.Lock() defer m.mutex.Unlock() c := newCall(m, methodName, arguments...) m.ExpectedCalls = append(m.ExpectedCalls, c) return c } // /* // Recording and responding to activity // */ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) { for i, call := range m.ExpectedCalls { if call.Method == method && call.Repeatability > -1 { _, diffCount := call.Arguments.Diff(arguments) if diffCount == 0 { return i, call } } } return -1, nil } func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) { diffCount := 0 var closestCall *Call for _, call := range m.expectedCalls() { if call.Method == method { _, tempDiffCount := call.Arguments.Diff(arguments) if tempDiffCount < diffCount || diffCount == 0 { diffCount = tempDiffCount closestCall = call } } } if closestCall == nil { return false, nil } return true, closestCall } func callString(method string, arguments Arguments, includeArgumentValues bool) string { var argValsString string if includeArgumentValues { var argVals []string for argIndex, arg := range arguments { argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg)) } argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t")) } return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString) } // Called tells the mock object that a method has been called, and gets an array // of arguments to return. Panics if the call is unexpected (i.e. 
not preceded by // appropriate .On .Return() calls) // If Call.WaitFor is set, blocks until the channel is closed or receives a message. func (m *Mock) Called(arguments ...interface{}) Arguments { // get the calling function's name pc, _, _, ok := runtime.Caller(1) if !ok { panic("Couldn't get the caller information") } functionPath := runtime.FuncForPC(pc).Name() //Next four lines are required to use GCCGO function naming conventions. //For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock //uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree //With GCCGO we need to remove interface information starting from pN

. re := regexp.MustCompile("\\.pN\\d+_") if re.MatchString(functionPath) { functionPath = re.Split(functionPath, -1)[0] } parts := strings.Split(functionPath, ".") functionName := parts[len(parts)-1] return m.MethodCalled(functionName, arguments...) } // MethodCalled tells the mock object that the given method has been called, and gets // an array of arguments to return. Panics if the call is unexpected (i.e. not preceded // by appropriate .On .Return() calls) // If Call.WaitFor is set, blocks until the channel is closed or receives a message. func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments { m.mutex.Lock() found, call := m.findExpectedCall(methodName, arguments...) if found < 0 { // we have to fail here - because we don't know what to do // as the return arguments. This is because: // // a) this is a totally unexpected call to this method, // b) the arguments are not what was expected, or // c) the developer has forgotten to add an accompanying On...Return pair. closestFound, closestCall := m.findClosestCall(methodName, arguments...) m.mutex.Unlock() if closestFound { panic(fmt.Sprintf("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\n", callString(methodName, arguments, true), callString(methodName, closestCall.Arguments, true), diffArguments(arguments, closestCall.Arguments))) } else { panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) 
first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo())) } } switch { case call.Repeatability == 1: call.Repeatability = -1 call.totalCalls++ case call.Repeatability > 1: call.Repeatability-- call.totalCalls++ case call.Repeatability == 0: call.totalCalls++ } // add the call m.Calls = append(m.Calls, *newCall(m, methodName, arguments...)) m.mutex.Unlock() // block if specified if call.WaitFor != nil { <-call.WaitFor } m.mutex.Lock() runFn := call.RunFn m.mutex.Unlock() if runFn != nil { runFn(arguments) } m.mutex.Lock() returnArgs := call.ReturnArguments m.mutex.Unlock() return returnArgs } /* Assertions */ type assertExpectationser interface { AssertExpectations(TestingT) bool } // AssertExpectationsForObjects asserts that everything specified with On and Return // of the specified objects was in fact called as expected. // // Calls may have occurred in any order. func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool { for _, obj := range testObjects { if m, ok := obj.(Mock); ok { t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)") obj = &m } m := obj.(assertExpectationser) if !m.AssertExpectations(t) { return false } } return true } // AssertExpectations asserts that everything specified with On and Return was // in fact called as expected. Calls may have occurred in any order. 
func (m *Mock) AssertExpectations(t TestingT) bool { m.mutex.Lock() defer m.mutex.Unlock() var somethingMissing bool var failedExpectations int // iterate through each expectation expectedCalls := m.expectedCalls() for _, expectedCall := range expectedCalls { if !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 { somethingMissing = true failedExpectations++ t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) } else { if expectedCall.Repeatability > 0 { somethingMissing = true failedExpectations++ } else { t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) } } } if somethingMissing { t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo()) } return !somethingMissing } // AssertNumberOfCalls asserts that the method was called expectedCalls times. func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool { m.mutex.Lock() defer m.mutex.Unlock() var actualCalls int for _, call := range m.calls() { if call.Method == methodName { actualCalls++ } } return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls)) } // AssertCalled asserts that the method was called. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. 
func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool { m.mutex.Lock() defer m.mutex.Unlock() if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) { t.Logf("%v", m.expectedCalls()) return false } return true } // AssertNotCalled asserts that the method was not called. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool { m.mutex.Lock() defer m.mutex.Unlock() if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) { t.Logf("%v", m.expectedCalls()) return false } return true } func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { for _, call := range m.calls() { if call.Method == methodName { _, differences := Arguments(expected).Diff(call.Arguments) if differences == 0 { // found the expected call return true } } } // we didn't find the expected call return false } func (m *Mock) expectedCalls() []*Call { return append([]*Call{}, m.ExpectedCalls...) } func (m *Mock) calls() []Call { return append([]Call{}, m.Calls...) } /* Arguments */ // Arguments holds an array of method arguments or return values. type Arguments []interface{} const ( // Anything is used in Diff and Assert when the argument being tested // shouldn't be taken into consideration. Anything string = "mock.Anything" ) // AnythingOfTypeArgument is a string that contains the type of an argument // for use when type checking. Used in Diff and Assert. type AnythingOfTypeArgument string // AnythingOfType returns an AnythingOfTypeArgument object containing the // name of the type to check for. Used in Diff and Assert. 
// // For example: // Assert(t, AnythingOfType("string"), AnythingOfType("int")) func AnythingOfType(t string) AnythingOfTypeArgument { return AnythingOfTypeArgument(t) } // argumentMatcher performs custom argument matching, returning whether or // not the argument is matched by the expectation fixture function. type argumentMatcher struct { // fn is a function which accepts one argument, and returns a bool. fn reflect.Value } func (f argumentMatcher) Matches(argument interface{}) bool { expectType := f.fn.Type().In(0) if reflect.TypeOf(argument).AssignableTo(expectType) { result := f.fn.Call([]reflect.Value{reflect.ValueOf(argument)}) return result[0].Bool() } return false } func (f argumentMatcher) String() string { return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name()) } // MatchedBy can be used to match a mock call based on only certain properties // from a complex struct or some calculation. It takes a function that will be // evaluated with the called argument and will return true when there's a match // and false otherwise. // // Example: // m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" })) // // |fn|, must be a function accepting a single argument (of the expected type) // which returns a bool. If |fn| doesn't match the required signature, // MatchedBy() panics. func MatchedBy(fn interface{}) argumentMatcher { fnType := reflect.TypeOf(fn) if fnType.Kind() != reflect.Func { panic(fmt.Sprintf("assert: arguments: %s is not a func", fn)) } if fnType.NumIn() != 1 { panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn)) } if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool { panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn)) } return argumentMatcher{fn: reflect.ValueOf(fn)} } // Get Returns the argument at the specified index. 
func (args Arguments) Get(index int) interface{} { if index+1 > len(args) { panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args))) } return args[index] } // Is gets whether the objects match the arguments specified. func (args Arguments) Is(objects ...interface{}) bool { for i, obj := range args { if obj != objects[i] { return false } } return true } // Diff gets a string describing the differences between the arguments // and the specified objects. // // Returns the diff string and number of differences found. func (args Arguments) Diff(objects []interface{}) (string, int) { var output = "\n" var differences int var maxArgCount = len(args) if len(objects) > maxArgCount { maxArgCount = len(objects) } for i := 0; i < maxArgCount; i++ { var actual, expected interface{} if len(objects) <= i { actual = "(Missing)" } else { actual = objects[i] } if len(args) <= i { expected = "(Missing)" } else { expected = args[i] } if matcher, ok := expected.(argumentMatcher); ok { if matcher.Matches(actual) { output = fmt.Sprintf("%s\t%d: \u2705 %s matched by %s\n", output, i, actual, matcher) } else { differences++ output = fmt.Sprintf("%s\t%d: \u2705 %s not matched by %s\n", output, i, actual, matcher) } } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { // type checking if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) { // not match differences++ output = fmt.Sprintf("%s\t%d: \u274C type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actual) } } else { // normal checking if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { // match output = fmt.Sprintf("%s\t%d: \u2705 %s == %s\n", output, i, actual, expected) } else { // not match differences++ output = 
fmt.Sprintf("%s\t%d: \u274C %s != %s\n", output, i, actual, expected) } } } if differences == 0 { return "No differences.", differences } return output, differences } // Assert compares the arguments with the specified objects and fails if // they do not exactly match. func (args Arguments) Assert(t TestingT, objects ...interface{}) bool { // get the differences diff, diffCount := args.Diff(objects) if diffCount == 0 { return true } // there are differences... report them... t.Logf(diff) t.Errorf("%sArguments do not match.", assert.CallerInfo()) return false } // String gets the argument at the specified index. Panics if there is no argument, or // if the argument is of the wrong type. // // If no index is provided, String() returns a complete string representation // of the arguments. func (args Arguments) String(indexOrNil ...int) string { if len(indexOrNil) == 0 { // normal String() method - return a string representation of the args var argsStr []string for _, arg := range args { argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg))) } return strings.Join(argsStr, ",") } else if len(indexOrNil) == 1 { // Index has been specified - get the argument at that index var index = indexOrNil[0] var s string var ok bool if s, ok = args.Get(index).(string); !ok { panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index))) } return s } panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil))) } // Int gets the argument at the specified index. Panics if there is no argument, or // if the argument is of the wrong type. func (args Arguments) Int(index int) int { var s int var ok bool if s, ok = args.Get(index).(int); !ok { panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index))) } return s } // Error gets the argument at the specified index. 
Panics if there is no argument, or // if the argument is of the wrong type. func (args Arguments) Error(index int) error { obj := args.Get(index) var s error var ok bool if obj == nil { return nil } if s, ok = obj.(error); !ok { panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index))) } return s } // Bool gets the argument at the specified index. Panics if there is no argument, or // if the argument is of the wrong type. func (args Arguments) Bool(index int) bool { var s bool var ok bool if s, ok = args.Get(index).(bool); !ok { panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index))) } return s } func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { t := reflect.TypeOf(v) k := t.Kind() if k == reflect.Ptr { t = t.Elem() k = t.Kind() } return t, k } func diffArguments(expected Arguments, actual Arguments) string { if len(expected) != len(actual) { return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual)) } for x := range expected { if diffString := diff(expected[x], actual[x]); diffString != "" { return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString) } } return "" } // diff returns a diff of both values as long as both are of the same type and // are a struct, map, slice or array. Otherwise it returns an empty string. 
func diff(expected interface{}, actual interface{}) string { if expected == nil || actual == nil { return "" } et, ek := typeAndKind(expected) at, _ := typeAndKind(actual) if et != at { return "" } if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { return "" } e := spewConfig.Sdump(expected) a := spewConfig.Sdump(actual) diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ A: difflib.SplitLines(e), B: difflib.SplitLines(a), FromFile: "Expected", FromDate: "", ToFile: "Actual", ToDate: "", Context: 1, }) return diff } var spewConfig = spew.ConfigState{ Indent: " ", DisablePointerAddresses: true, DisableCapacities: true, SortKeys: true, } ================================================ FILE: vendor/github.com/stretchr/testify/require/doc.go ================================================ // Package require implements the same assertions as the `assert` package but // stops test execution when a test fails. // // Example Usage // // The following is a complete example using require in a standard test function: // import ( // "testing" // "github.com/stretchr/testify/require" // ) // // func TestSomething(t *testing.T) { // // var a string = "Hello" // var b string = "Hello" // // require.Equal(t, a, b, "The two words should be the same.") // // } // // Assertions // // The `require` package have same global functions as in the `assert` package, // but instead of returning a boolean result they call `t.FailNow()`. // // Every assertion function also takes an optional string message as the final argument, // allowing custom error messages to be appended to the message the assertion method outputs. package require ================================================ FILE: vendor/github.com/stretchr/testify/require/forward_requirements.go ================================================ package require // Assertions provides assertion methods around the // TestingT interface. 
type Assertions struct {
	t TestingT
}

// New makes a new Assertions object for the specified TestingT.
func New(t TestingT) *Assertions {
	return &Assertions{
		t: t,
	}
}

//go:generate go run ../_codegen/main.go -output-package=require -template=require_forward.go.tmpl -include-format-funcs
================================================
FILE: vendor/github.com/stretchr/testify/require/require.go
================================================
/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
 */

package require

import (
	assert "github.com/stretchr/testify/assert"
	http "net/http"
	url "net/url"
	time "time"
)

// NOTE(review): each function visible here delegates to the identically named
// function in package assert and calls t.FailNow() when that reports failure;
// none of them return a value. Because this file is generated, comment fixes
// made here are lost on regeneration and should also go into the _codegen
// templates.

// Condition uses a Comparison to assert a complex condition.
func Condition(t TestingT, comp assert.Comparison, msgAndArgs ...interface{}) {
	if !assert.Condition(t, comp, msgAndArgs...) {
		t.FailNow()
	}
}

// Conditionf uses a Comparison to assert a complex condition.
func Conditionf(t TestingT, comp assert.Comparison, msg string, args ...interface{}) {
	if !assert.Conditionf(t, comp, msg, args...) {
		t.FailNow()
	}
}

// Contains asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
//    assert.Contains(t, "Hello World", "World")
//    assert.Contains(t, ["Hello", "World"], "World")
//    assert.Contains(t, {"Hello": "World"}, "Hello")
//
// Calls t.FailNow() if the assertion fails.
func Contains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) {
	if !assert.Contains(t, s, contains, msgAndArgs...) {
		t.FailNow()
	}
}

// Containsf asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
//    assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
//    assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
//    assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
//
// Calls t.FailNow() if the assertion fails.
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) {
	if !assert.Containsf(t, s, contains, msg, args...) {
		t.FailNow()
	}
}

// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
//  assert.Empty(t, obj)
//
// Calls t.FailNow() if the assertion fails.
func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) {
	if !assert.Empty(t, object, msgAndArgs...) {
		t.FailNow()
	}
}

// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
//  assert.Emptyf(t, obj, "error message %s", "formatted")
//
// Calls t.FailNow() if the assertion fails.
func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) {
	if !assert.Emptyf(t, object, msg, args...) {
		t.FailNow()
	}
}

// Equal asserts that two objects are equal.
//
//    assert.Equal(t, 123, 123)
//
// Calls t.FailNow() if the assertion fails.
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail.
func Equal(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
	if !assert.Equal(t, expected, actual, msgAndArgs...) {
		t.FailNow()
	}
}

// EqualError asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
// // actualObj, err := SomeFunction() // assert.EqualError(t, err, expectedErrorString) // // Returns whether the assertion was successful (true) or not (false). func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) { if !assert.EqualError(t, theError, errString, msgAndArgs...) { t.FailNow() } } // EqualErrorf asserts that a function returned an error (i.e. not `nil`) // and that it is equal to the provided error. // // actualObj, err := SomeFunction() // assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) { if !assert.EqualErrorf(t, theError, errString, msg, args...) { t.FailNow() } } // EqualValues asserts that two objects are equal or convertable to the same types // and equal. // // assert.EqualValues(t, uint32(123), int32(123)) // // Returns whether the assertion was successful (true) or not (false). func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if !assert.EqualValues(t, expected, actual, msgAndArgs...) { t.FailNow() } } // EqualValuesf asserts that two objects are equal or convertable to the same types // and equal. // // assert.EqualValuesf(t, uint32(123, "error message %s", "formatted"), int32(123)) // // Returns whether the assertion was successful (true) or not (false). func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if !assert.EqualValuesf(t, expected, actual, msg, args...) { t.FailNow() } } // Equalf asserts that two objects are equal. // // assert.Equalf(t, 123, 123, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). 
Function equality // cannot be determined and will always fail. func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if !assert.Equalf(t, expected, actual, msg, args...) { t.FailNow() } } // Error asserts that a function returned an error (i.e. not `nil`). // // actualObj, err := SomeFunction() // if assert.Error(t, err) { // assert.Equal(t, expectedError, err) // } // // Returns whether the assertion was successful (true) or not (false). func Error(t TestingT, err error, msgAndArgs ...interface{}) { if !assert.Error(t, err, msgAndArgs...) { t.FailNow() } } // Errorf asserts that a function returned an error (i.e. not `nil`). // // actualObj, err := SomeFunction() // if assert.Errorf(t, err, "error message %s", "formatted") { // assert.Equal(t, expectedErrorf, err) // } // // Returns whether the assertion was successful (true) or not (false). func Errorf(t TestingT, err error, msg string, args ...interface{}) { if !assert.Errorf(t, err, msg, args...) { t.FailNow() } } // Exactly asserts that two objects are equal is value and type. // // assert.Exactly(t, int32(123), int64(123)) // // Returns whether the assertion was successful (true) or not (false). func Exactly(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if !assert.Exactly(t, expected, actual, msgAndArgs...) { t.FailNow() } } // Exactlyf asserts that two objects are equal is value and type. // // assert.Exactlyf(t, int32(123, "error message %s", "formatted"), int64(123)) // // Returns whether the assertion was successful (true) or not (false). func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if !assert.Exactlyf(t, expected, actual, msg, args...) { t.FailNow() } } // Fail reports a failure through func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) { if !assert.Fail(t, failureMessage, msgAndArgs...) 
{ t.FailNow() } } // FailNow fails test func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) { if !assert.FailNow(t, failureMessage, msgAndArgs...) { t.FailNow() } } // FailNowf fails test func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) { if !assert.FailNowf(t, failureMessage, msg, args...) { t.FailNow() } } // Failf reports a failure through func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) { if !assert.Failf(t, failureMessage, msg, args...) { t.FailNow() } } // False asserts that the specified value is false. // // assert.False(t, myBool) // // Returns whether the assertion was successful (true) or not (false). func False(t TestingT, value bool, msgAndArgs ...interface{}) { if !assert.False(t, value, msgAndArgs...) { t.FailNow() } } // Falsef asserts that the specified value is false. // // assert.Falsef(t, myBool, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Falsef(t TestingT, value bool, msg string, args ...interface{}) { if !assert.Falsef(t, value, msg, args...) { t.FailNow() } } // HTTPBodyContains asserts that a specified handler returns a // body that contains a string. // // assert.HTTPBodyContains(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) { if !assert.HTTPBodyContains(t, handler, method, url, values, str) { t.FailNow() } } // HTTPBodyContainsf asserts that a specified handler returns a // body that contains a string. // // assert.HTTPBodyContainsf(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) {
	if !assert.HTTPBodyContainsf(t, handler, method, url, values, str) {
		t.FailNow()
	}
}

// HTTPBodyNotContains asserts that a specified handler returns a
// body that does not contain a string.
// If the assertion fails, t.FailNow is called.
//
//	require.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) {
	if !assert.HTTPBodyNotContains(t, handler, method, url, values, str) {
		t.FailNow()
	}
}

// HTTPBodyNotContainsf asserts that a specified handler returns a
// body that does not contain a string.
// If the assertion fails, t.FailNow is called.
//
// NOTE: despite the "f" suffix, this function accepts no msg/args
// parameters for a formatted failure message.
//
//	require.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) {
	if !assert.HTTPBodyNotContainsf(t, handler, method, url, values, str) {
		t.FailNow()
	}
}

// HTTPError asserts that a specified handler returns an error status code.
// If the assertion fails, t.FailNow is called.
//
//	require.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}})
func HTTPError(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) {
	if !assert.HTTPError(t, handler, method, url, values) {
		t.FailNow()
	}
}

// HTTPErrorf asserts that a specified handler returns an error status code.
// If the assertion fails, t.FailNow is called.
//
// NOTE: despite the "f" suffix, this function accepts no msg/args
// parameters for a formatted failure message.
//
//	require.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}})
func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) {
	if !assert.HTTPErrorf(t, handler, method, url, values) {
		t.FailNow()
	}
}

// HTTPRedirect asserts that a specified handler returns a redirect status code.
// If the assertion fails, t.FailNow is called.
//
//	require.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}})
func HTTPRedirect(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) {
	if !assert.HTTPRedirect(t, handler, method, url, values) {
		t.FailNow()
	}
}

// HTTPRedirectf asserts that a specified handler returns a redirect status code.
// If the assertion fails, t.FailNow is called.
//
// NOTE: despite the "f" suffix, this function accepts no msg/args
// parameters for a formatted failure message.
//
//	require.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}})
func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) {
	if !assert.HTTPRedirectf(t, handler, method, url, values) {
		t.FailNow()
	}
}

// HTTPSuccess asserts that a specified handler returns a success status code.
// If the assertion fails, t.FailNow is called.
//
//	require.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil)
func HTTPSuccess(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) {
	if !assert.HTTPSuccess(t, handler, method, url, values) {
		t.FailNow()
	}
}

// HTTPSuccessf asserts that a specified handler returns a success status code.
// If the assertion fails, t.FailNow is called.
//
// NOTE: despite the "f" suffix, this function accepts no msg/args
// parameters for a formatted failure message.
//
//	require.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil)
func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values) {
	if !assert.HTTPSuccessf(t, handler, method, url, values) {
		t.FailNow()
	}
}

// Implements asserts that an object is implemented by the specified interface.
// // assert.Implements(t, (*MyInterface)(nil), new(MyObject)) func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) { if !assert.Implements(t, interfaceObject, object, msgAndArgs...) { t.FailNow() } } // Implementsf asserts that an object is implemented by the specified interface. // // assert.Implementsf(t, (*MyInterface, "error message %s", "formatted")(nil), new(MyObject)) func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) { if !assert.Implementsf(t, interfaceObject, object, msg, args...) { t.FailNow() } } // InDelta asserts that the two numerals are within delta of each other. // // assert.InDelta(t, math.Pi, (22 / 7.0), 0.01) // // Returns whether the assertion was successful (true) or not (false). func InDelta(t TestingT, expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) { if !assert.InDelta(t, expected, actual, delta, msgAndArgs...) { t.FailNow() } } // InDeltaSlice is the same as InDelta, except it compares two slices. func InDeltaSlice(t TestingT, expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) { if !assert.InDeltaSlice(t, expected, actual, delta, msgAndArgs...) { t.FailNow() } } // InDeltaSlicef is the same as InDelta, except it compares two slices. func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) { if !assert.InDeltaSlicef(t, expected, actual, delta, msg, args...) { t.FailNow() } } // InDeltaf asserts that the two numerals are within delta of each other. // // assert.InDeltaf(t, math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01) // // Returns whether the assertion was successful (true) or not (false). func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) { if !assert.InDeltaf(t, expected, actual, delta, msg, args...) 
{ t.FailNow() } } // InEpsilon asserts that expected and actual have a relative error less than epsilon // // Returns whether the assertion was successful (true) or not (false). func InEpsilon(t TestingT, expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) { if !assert.InEpsilon(t, expected, actual, epsilon, msgAndArgs...) { t.FailNow() } } // InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. func InEpsilonSlice(t TestingT, expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) { if !assert.InEpsilonSlice(t, expected, actual, epsilon, msgAndArgs...) { t.FailNow() } } // InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices. func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) { if !assert.InEpsilonSlicef(t, expected, actual, epsilon, msg, args...) { t.FailNow() } } // InEpsilonf asserts that expected and actual have a relative error less than epsilon // // Returns whether the assertion was successful (true) or not (false). func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) { if !assert.InEpsilonf(t, expected, actual, epsilon, msg, args...) { t.FailNow() } } // IsType asserts that the specified objects are of the same type. func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) { if !assert.IsType(t, expectedType, object, msgAndArgs...) { t.FailNow() } } // IsTypef asserts that the specified objects are of the same type. func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) { if !assert.IsTypef(t, expectedType, object, msg, args...) { t.FailNow() } } // JSONEq asserts that two JSON strings are equivalent. 
// // assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) // // Returns whether the assertion was successful (true) or not (false). func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) { if !assert.JSONEq(t, expected, actual, msgAndArgs...) { t.FailNow() } } // JSONEqf asserts that two JSON strings are equivalent. // // assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) { if !assert.JSONEqf(t, expected, actual, msg, args...) { t.FailNow() } } // Len asserts that the specified object has specific length. // Len also fails if the object has a type that len() not accept. // // assert.Len(t, mySlice, 3) // // Returns whether the assertion was successful (true) or not (false). func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) { if !assert.Len(t, object, length, msgAndArgs...) { t.FailNow() } } // Lenf asserts that the specified object has specific length. // Lenf also fails if the object has a type that len() not accept. // // assert.Lenf(t, mySlice, 3, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) { if !assert.Lenf(t, object, length, msg, args...) { t.FailNow() } } // Nil asserts that the specified object is nil. // // assert.Nil(t, err) // // Returns whether the assertion was successful (true) or not (false). func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) { if !assert.Nil(t, object, msgAndArgs...) { t.FailNow() } } // Nilf asserts that the specified object is nil. 
// // assert.Nilf(t, err, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) { if !assert.Nilf(t, object, msg, args...) { t.FailNow() } } // NoError asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() // if assert.NoError(t, err) { // assert.Equal(t, expectedObj, actualObj) // } // // Returns whether the assertion was successful (true) or not (false). func NoError(t TestingT, err error, msgAndArgs ...interface{}) { if !assert.NoError(t, err, msgAndArgs...) { t.FailNow() } } // NoErrorf asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() // if assert.NoErrorf(t, err, "error message %s", "formatted") { // assert.Equal(t, expectedObj, actualObj) // } // // Returns whether the assertion was successful (true) or not (false). func NoErrorf(t TestingT, err error, msg string, args ...interface{}) { if !assert.NoErrorf(t, err, msg, args...) { t.FailNow() } } // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // // assert.NotContains(t, "Hello World", "Earth") // assert.NotContains(t, ["Hello", "World"], "Earth") // assert.NotContains(t, {"Hello": "World"}, "Earth") // // Returns whether the assertion was successful (true) or not (false). func NotContains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) { if !assert.NotContains(t, s, contains, msgAndArgs...) { t.FailNow() } } // NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. 
// // assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") // assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") // assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) { if !assert.NotContainsf(t, s, contains, msg, args...) { t.FailNow() } } // NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // // if assert.NotEmpty(t, obj) { // assert.Equal(t, "two", obj[1]) // } // // Returns whether the assertion was successful (true) or not (false). func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { if !assert.NotEmpty(t, object, msgAndArgs...) { t.FailNow() } } // NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // // if assert.NotEmptyf(t, obj, "error message %s", "formatted") { // assert.Equal(t, "two", obj[1]) // } // // Returns whether the assertion was successful (true) or not (false). func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) { if !assert.NotEmptyf(t, object, msg, args...) { t.FailNow() } } // NotEqual asserts that the specified values are NOT equal. // // assert.NotEqual(t, obj1, obj2) // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). func NotEqual(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if !assert.NotEqual(t, expected, actual, msgAndArgs...) { t.FailNow() } } // NotEqualf asserts that the specified values are NOT equal. 
// // assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if !assert.NotEqualf(t, expected, actual, msg, args...) { t.FailNow() } } // NotNil asserts that the specified object is not nil. // // assert.NotNil(t, err) // // Returns whether the assertion was successful (true) or not (false). func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) { if !assert.NotNil(t, object, msgAndArgs...) { t.FailNow() } } // NotNilf asserts that the specified object is not nil. // // assert.NotNilf(t, err, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) { if !assert.NotNilf(t, object, msg, args...) { t.FailNow() } } // NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. // // assert.NotPanics(t, func(){ RemainCalm() }) // // Returns whether the assertion was successful (true) or not (false). func NotPanics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if !assert.NotPanics(t, f, msgAndArgs...) { t.FailNow() } } // NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. // // assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func NotPanicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{}) { if !assert.NotPanicsf(t, f, msg, args...) { t.FailNow() } } // NotRegexp asserts that a specified regexp does not match a string. 
// // assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") // assert.NotRegexp(t, "^start", "it's not starting") // // Returns whether the assertion was successful (true) or not (false). func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) { if !assert.NotRegexp(t, rx, str, msgAndArgs...) { t.FailNow() } } // NotRegexpf asserts that a specified regexp does not match a string. // // assert.NotRegexpf(t, regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting") // assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) { if !assert.NotRegexpf(t, rx, str, msg, args...) { t.FailNow() } } // NotSubset asserts that the specified list(array, slice...) contains not all // elements given in the specified subset(array, slice...). // // assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // // Returns whether the assertion was successful (true) or not (false). func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { if !assert.NotSubset(t, list, subset, msgAndArgs...) { t.FailNow() } } // NotSubsetf asserts that the specified list(array, slice...) contains not all // elements given in the specified subset(array, slice...). // // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { if !assert.NotSubsetf(t, list, subset, msg, args...) { t.FailNow() } } // NotZero asserts that i is not the zero value for its type and returns the truth. 
func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) {
	if !assert.NotZero(t, i, msgAndArgs...) {
		t.FailNow()
	}
}

// NotZerof asserts that i is not the zero value for its type.
// If the assertion fails, t.FailNow is called.
func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) {
	if !assert.NotZerof(t, i, msg, args...) {
		t.FailNow()
	}
}

// Panics asserts that the code inside the specified PanicTestFunc panics.
// If the assertion fails, t.FailNow is called.
//
//	require.Panics(t, func() { GoCrazy() })
func Panics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) {
	if !assert.Panics(t, f, msgAndArgs...) {
		t.FailNow()
	}
}

// PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
// If the assertion fails, t.FailNow is called.
//
//	require.PanicsWithValue(t, "crazy error", func() { GoCrazy() })
func PanicsWithValue(t TestingT, expected interface{}, f assert.PanicTestFunc, msgAndArgs ...interface{}) {
	if !assert.PanicsWithValue(t, expected, f, msgAndArgs...) {
		t.FailNow()
	}
}

// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
// If the assertion fails, t.FailNow is called.
//
//	require.PanicsWithValuef(t, "crazy error", func() { GoCrazy() }, "error message %s", "formatted")
func PanicsWithValuef(t TestingT, expected interface{}, f assert.PanicTestFunc, msg string, args ...interface{}) {
	if !assert.PanicsWithValuef(t, expected, f, msg, args...) {
		t.FailNow()
	}
}

// Panicsf asserts that the code inside the specified PanicTestFunc panics.
// If the assertion fails, t.FailNow is called.
//
//	require.Panicsf(t, func() { GoCrazy() }, "error message %s", "formatted")
func Panicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{}) { if !assert.Panicsf(t, f, msg, args...) { t.FailNow() } } // Regexp asserts that a specified regexp matches a string. // // assert.Regexp(t, regexp.MustCompile("start"), "it's starting") // assert.Regexp(t, "start...$", "it's not starting") // // Returns whether the assertion was successful (true) or not (false). func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) { if !assert.Regexp(t, rx, str, msgAndArgs...) { t.FailNow() } } // Regexpf asserts that a specified regexp matches a string. // // assert.Regexpf(t, regexp.MustCompile("start", "error message %s", "formatted"), "it's starting") // assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) { if !assert.Regexpf(t, rx, str, msg, args...) { t.FailNow() } } // Subset asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // // assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // // Returns whether the assertion was successful (true) or not (false). func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { if !assert.Subset(t, list, subset, msgAndArgs...) { t.FailNow() } } // Subsetf asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // // assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { if !assert.Subsetf(t, list, subset, msg, args...) 
{ t.FailNow() } } // True asserts that the specified value is true. // // assert.True(t, myBool) // // Returns whether the assertion was successful (true) or not (false). func True(t TestingT, value bool, msgAndArgs ...interface{}) { if !assert.True(t, value, msgAndArgs...) { t.FailNow() } } // Truef asserts that the specified value is true. // // assert.Truef(t, myBool, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func Truef(t TestingT, value bool, msg string, args ...interface{}) { if !assert.Truef(t, value, msg, args...) { t.FailNow() } } // WithinDuration asserts that the two times are within duration delta of each other. // // assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) // // Returns whether the assertion was successful (true) or not (false). func WithinDuration(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) { if !assert.WithinDuration(t, expected, actual, delta, msgAndArgs...) { t.FailNow() } } // WithinDurationf asserts that the two times are within duration delta of each other. // // assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) { if !assert.WithinDurationf(t, expected, actual, delta, msg, args...) { t.FailNow() } } // Zero asserts that i is the zero value for its type and returns the truth. func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) { if !assert.Zero(t, i, msgAndArgs...) { t.FailNow() } } // Zerof asserts that i is the zero value for its type and returns the truth. func Zerof(t TestingT, i interface{}, msg string, args ...interface{}) { if !assert.Zerof(t, i, msg, args...) 
{ t.FailNow() } } ================================================ FILE: vendor/github.com/stretchr/testify/require/require.go.tmpl ================================================ {{.Comment}} func {{.DocInfo.Name}}(t TestingT, {{.Params}}) { if !assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { t.FailNow() } } ================================================ FILE: vendor/github.com/stretchr/testify/require/require_forward.go ================================================ /* * CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen * THIS FILE MUST NOT BE EDITED BY HAND */ package require import ( assert "github.com/stretchr/testify/assert" http "net/http" url "net/url" time "time" ) // Condition uses a Comparison to assert a complex condition. func (a *Assertions) Condition(comp assert.Comparison, msgAndArgs ...interface{}) { Condition(a.t, comp, msgAndArgs...) } // Conditionf uses a Comparison to assert a complex condition. func (a *Assertions) Conditionf(comp assert.Comparison, msg string, args ...interface{}) { Conditionf(a.t, comp, msg, args...) } // Contains asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // // a.Contains("Hello World", "World") // a.Contains(["Hello", "World"], "World") // a.Contains({"Hello": "World"}, "Hello") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) { Contains(a.t, s, contains, msgAndArgs...) } // Containsf asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // // a.Containsf("Hello World", "World", "error message %s", "formatted") // a.Containsf(["Hello", "World"], "World", "error message %s", "formatted") // a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string, args ...interface{}) {
	Containsf(a.t, s, contains, msg, args...)
}

// Empty asserts that the specified object is empty, e.g. nil, "", false, 0,
// or a slice or channel with len == 0.
// If the assertion fails, FailNow is called on the wrapped TestingT.
//
//	a.Empty(obj)
func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) {
	Empty(a.t, object, msgAndArgs...)
}

// Emptyf asserts that the specified object is empty, e.g. nil, "", false, 0,
// or a slice or channel with len == 0.
// If the assertion fails, FailNow is called on the wrapped TestingT.
//
//	a.Emptyf(obj, "error message %s", "formatted")
func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) {
	Emptyf(a.t, object, msg, args...)
}

// Equal asserts that two objects are equal.
// If the assertion fails, FailNow is called on the wrapped TestingT.
//
//	a.Equal(123, 123)
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail.
func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
	Equal(a.t, expected, actual, msgAndArgs...)
}

// EqualError asserts that a function returned an error (i.e. not `nil`)
// and that its message is equal to errString.
// If the assertion fails, FailNow is called on the wrapped TestingT.
//
//	actualObj, err := SomeFunction()
//	a.EqualError(err, expectedErrorString)
func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) {
	EqualError(a.t, theError, errString, msgAndArgs...)
}

// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
// and that its message is equal to errString.
// // actualObj, err := SomeFunction() // a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) EqualErrorf(theError error, errString string, msg string, args ...interface{}) { EqualErrorf(a.t, theError, errString, msg, args...) } // EqualValues asserts that two objects are equal or convertable to the same types // and equal. // // a.EqualValues(uint32(123), int32(123)) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) { EqualValues(a.t, expected, actual, msgAndArgs...) } // EqualValuesf asserts that two objects are equal or convertable to the same types // and equal. // // a.EqualValuesf(uint32(123, "error message %s", "formatted"), int32(123)) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) { EqualValuesf(a.t, expected, actual, msg, args...) } // Equalf asserts that two objects are equal. // // a.Equalf(123, 123, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality // cannot be determined and will always fail. func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string, args ...interface{}) { Equalf(a.t, expected, actual, msg, args...) } // Error asserts that a function returned an error (i.e. not `nil`). // // actualObj, err := SomeFunction() // if a.Error(err) { // assert.Equal(t, expectedError, err) // } // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) {
	Error(a.t, err, msgAndArgs...)
}

// Errorf asserts that a function returned an error (i.e. not `nil`).
//
//	actualObj, err := SomeFunction()
//	a.Errorf(err, "error message %s", "formatted")
//	assert.Equal(t, expectedErrorf, err) // only reached when err != nil
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) {
	Errorf(a.t, err, msg, args...)
}

// Exactly asserts that two objects are equal in value and type.
//
//	a.Exactly(int32(123), int64(123))
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
	Exactly(a.t, expected, actual, msgAndArgs...)
}

// Exactlyf asserts that two objects are equal in value and type.
//
//	a.Exactlyf(int32(123), int64(123), "error message %s", "formatted")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) {
	Exactlyf(a.t, expected, actual, msg, args...)
}

// Fail unconditionally reports a failure with failureMessage (plus the
// optional msgAndArgs) and stops the test via t.FailNow.
func (a *Assertions) Fail(failureMessage string, msgAndArgs ...interface{}) {
	Fail(a.t, failureMessage, msgAndArgs...)
}

// FailNow unconditionally reports a failure with failureMessage (plus the
// optional msgAndArgs) and stops the test via t.FailNow.
func (a *Assertions) FailNow(failureMessage string, msgAndArgs ...interface{}) {
	FailNow(a.t, failureMessage, msgAndArgs...)
}

// FailNowf unconditionally reports a failure with failureMessage and the
// message built from msg and args, and stops the test via t.FailNow.
func (a *Assertions) FailNowf(failureMessage string, msg string, args ...interface{}) {
	FailNowf(a.t, failureMessage, msg, args...)
}

// Failf unconditionally reports a failure with failureMessage and the
// message built from msg and args, and stops the test via t.FailNow.
func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{}) {
	Failf(a.t, failureMessage, msg, args...)
}

// False asserts that the specified value is false.
//
//	a.False(myBool)
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) False(value bool, msgAndArgs ...interface{}) {
	False(a.t, value, msgAndArgs...)
}

// Falsef asserts that the specified value is false.
//
//	a.Falsef(myBool, "error message %s", "formatted")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) {
	Falsef(a.t, value, msg, args...)
}

// HTTPBodyContains asserts that a specified handler returns a
// body that contains a string.
//
//	a.HTTPBodyContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) {
	HTTPBodyContains(a.t, handler, method, url, values, str)
}

// HTTPBodyContainsf asserts that a specified handler returns a
// body that contains a string.
//
// Despite the f suffix, this method takes no msg/args parameters, so no
// custom failure message can be supplied.
//
//	a.HTTPBodyContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) {
	HTTPBodyContainsf(a.t, handler, method, url, values, str)
}

// HTTPBodyNotContains asserts that a specified handler returns a
// body that does not contain a string.
//
//	a.HTTPBodyNotContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) {
	HTTPBodyNotContains(a.t, handler, method, url, values, str)
}

// HTTPBodyNotContainsf asserts that a specified handler returns a
// body that does not contain a string.
// // a.HTTPBodyNotContainsf(myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) { HTTPBodyNotContainsf(a.t, handler, method, url, values, str) } // HTTPError asserts that a specified handler returns an error status code. // // a.HTTPError(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values) { HTTPError(a.t, handler, method, url, values) } // HTTPErrorf asserts that a specified handler returns an error status code. // // a.HTTPErrorf(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false). func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values) { HTTPErrorf(a.t, handler, method, url, values) } // HTTPRedirect asserts that a specified handler returns a redirect status code. // // a.HTTPRedirect(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values) { HTTPRedirect(a.t, handler, method, url, values) } // HTTPRedirectf asserts that a specified handler returns a redirect status code. // // a.HTTPRedirectf(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false). 
func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values) { HTTPRedirectf(a.t, handler, method, url, values) } // HTTPSuccess asserts that a specified handler returns a success status code. // // a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values) { HTTPSuccess(a.t, handler, method, url, values) } // HTTPSuccessf asserts that a specified handler returns a success status code. // // a.HTTPSuccessf(myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url string, values url.Values) { HTTPSuccessf(a.t, handler, method, url, values) } // Implements asserts that an object is implemented by the specified interface. // // a.Implements((*MyInterface)(nil), new(MyObject)) func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) { Implements(a.t, interfaceObject, object, msgAndArgs...) } // Implementsf asserts that an object is implemented by the specified interface. // // a.Implementsf((*MyInterface, "error message %s", "formatted")(nil), new(MyObject)) func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) { Implementsf(a.t, interfaceObject, object, msg, args...) } // InDelta asserts that the two numerals are within delta of each other. // // a.InDelta(math.Pi, (22 / 7.0), 0.01) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) { InDelta(a.t, expected, actual, delta, msgAndArgs...) 
} // InDeltaSlice is the same as InDelta, except it compares two slices. func (a *Assertions) InDeltaSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) { InDeltaSlice(a.t, expected, actual, delta, msgAndArgs...) } // InDeltaSlicef is the same as InDelta, except it compares two slices. func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) { InDeltaSlicef(a.t, expected, actual, delta, msg, args...) } // InDeltaf asserts that the two numerals are within delta of each other. // // a.InDeltaf(math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) { InDeltaf(a.t, expected, actual, delta, msg, args...) } // InEpsilon asserts that expected and actual have a relative error less than epsilon // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) InEpsilon(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) { InEpsilon(a.t, expected, actual, epsilon, msgAndArgs...) } // InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. func (a *Assertions) InEpsilonSlice(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) { InEpsilonSlice(a.t, expected, actual, epsilon, msgAndArgs...) } // InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices. func (a *Assertions) InEpsilonSlicef(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) { InEpsilonSlicef(a.t, expected, actual, epsilon, msg, args...) } // InEpsilonf asserts that expected and actual have a relative error less than epsilon // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) {
	InEpsilonf(a.t, expected, actual, epsilon, msg, args...)
}

// IsType asserts that the specified objects are of the same type.
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) {
	IsType(a.t, expectedType, object, msgAndArgs...)
}

// IsTypef asserts that the specified objects are of the same type.
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) {
	IsTypef(a.t, expectedType, object, msg, args...)
}

// JSONEq asserts that two JSON strings are equivalent.
//
//	a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) {
	JSONEq(a.t, expected, actual, msgAndArgs...)
}

// JSONEqf asserts that two JSON strings are equivalent.
//
//	a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ...interface{}) {
	JSONEqf(a.t, expected, actual, msg, args...)
}

// Len asserts that the specified object has the specified length.
// Len also fails if the object has a type that len() does not accept.
//
//	a.Len(mySlice, 3)
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) {
	Len(a.t, object, length, msgAndArgs...)
}

// Lenf asserts that the specified object has the specified length.
// Lenf also fails if the object has a type that len() does not accept.
//
//	a.Lenf(mySlice, 3, "error message %s", "formatted")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...interface{}) { Lenf(a.t, object, length, msg, args...) } // Nil asserts that the specified object is nil. // // a.Nil(err) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) { Nil(a.t, object, msgAndArgs...) } // Nilf asserts that the specified object is nil. // // a.Nilf(err, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) { Nilf(a.t, object, msg, args...) } // NoError asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() // if a.NoError(err) { // assert.Equal(t, expectedObj, actualObj) // } // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) { NoError(a.t, err, msgAndArgs...) } // NoErrorf asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() // if a.NoErrorf(err, "error message %s", "formatted") { // assert.Equal(t, expectedObj, actualObj) // } // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) { NoErrorf(a.t, err, msg, args...) } // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // // a.NotContains("Hello World", "Earth") // a.NotContains(["Hello", "World"], "Earth") // a.NotContains({"Hello": "World"}, "Earth") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) { NotContains(a.t, s, contains, msgAndArgs...) } // NotContainsf asserts that the specified string, list(array, slice...) 
or map does NOT contain the // specified substring or element. // // a.NotContainsf("Hello World", "Earth", "error message %s", "formatted") // a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted") // a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg string, args ...interface{}) { NotContainsf(a.t, s, contains, msg, args...) } // NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // // if a.NotEmpty(obj) { // assert.Equal(t, "two", obj[1]) // } // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) { NotEmpty(a.t, object, msgAndArgs...) } // NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // // if a.NotEmptyf(obj, "error message %s", "formatted") { // assert.Equal(t, "two", obj[1]) // } // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface{}) { NotEmptyf(a.t, object, msg, args...) } // NotEqual asserts that the specified values are NOT equal. // // a.NotEqual(obj1, obj2) // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndArgs ...interface{}) { NotEqual(a.t, expected, actual, msgAndArgs...) } // NotEqualf asserts that the specified values are NOT equal. 
// // a.NotEqualf(obj1, obj2, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg string, args ...interface{}) { NotEqualf(a.t, expected, actual, msg, args...) } // NotNil asserts that the specified object is not nil. // // a.NotNil(err) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) { NotNil(a.t, object, msgAndArgs...) } // NotNilf asserts that the specified object is not nil. // // a.NotNilf(err, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}) { NotNilf(a.t, object, msg, args...) } // NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. // // a.NotPanics(func(){ RemainCalm() }) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotPanics(f assert.PanicTestFunc, msgAndArgs ...interface{}) { NotPanics(a.t, f, msgAndArgs...) } // NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. // // a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) NotPanicsf(f assert.PanicTestFunc, msg string, args ...interface{}) { NotPanicsf(a.t, f, msg, args...) } // NotRegexp asserts that a specified regexp does not match a string. // // a.NotRegexp(regexp.MustCompile("starts"), "it's starting") // a.NotRegexp("^start", "it's not starting") // // Returns whether the assertion was successful (true) or not (false). 
func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) {
	NotRegexp(a.t, rx, str, msgAndArgs...)
}

// NotRegexpf asserts that a specified regexp does not match a string.
//
//	a.NotRegexpf(regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted")
//	a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) {
	NotRegexpf(a.t, rx, str, msg, args...)
}

// NotSubset asserts that the specified list(array, slice...) does not contain
// all elements given in the specified subset(array, slice...).
//
//	a.NotSubset([]int{1, 3, 4}, []int{1, 2}, "But [1, 3, 4] does not contain [1, 2]")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
	NotSubset(a.t, list, subset, msgAndArgs...)
}

// NotSubsetf asserts that the specified list(array, slice...) does not contain
// all elements given in the specified subset(array, slice...).
//
//	a.NotSubsetf([]int{1, 3, 4}, []int{1, 2}, "error message %s", "formatted")
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
	NotSubsetf(a.t, list, subset, msg, args...)
}

// NotZero asserts that i is not the zero value for its type.
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) NotZero(i interface{}, msgAndArgs ...interface{}) {
	NotZero(a.t, i, msgAndArgs...)
}

// NotZerof asserts that i is not the zero value for its type.
//
// If the assertion fails, the test is stopped immediately via t.FailNow.
func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) {
	NotZerof(a.t, i, msg, args...)
}

// Panics asserts that the code inside the specified PanicTestFunc panics.
// // a.Panics(func(){ GoCrazy() }) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Panics(f assert.PanicTestFunc, msgAndArgs ...interface{}) { Panics(a.t, f, msgAndArgs...) } // PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // // a.PanicsWithValue("crazy error", func(){ GoCrazy() }) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) PanicsWithValue(expected interface{}, f assert.PanicTestFunc, msgAndArgs ...interface{}) { PanicsWithValue(a.t, expected, f, msgAndArgs...) } // PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // // a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) PanicsWithValuef(expected interface{}, f assert.PanicTestFunc, msg string, args ...interface{}) { PanicsWithValuef(a.t, expected, f, msg, args...) } // Panicsf asserts that the code inside the specified PanicTestFunc panics. // // a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Panicsf(f assert.PanicTestFunc, msg string, args ...interface{}) { Panicsf(a.t, f, msg, args...) } // Regexp asserts that a specified regexp matches a string. // // a.Regexp(regexp.MustCompile("start"), "it's starting") // a.Regexp("start...$", "it's not starting") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) { Regexp(a.t, rx, str, msgAndArgs...) } // Regexpf asserts that a specified regexp matches a string. 
// // a.Regexpf(regexp.MustCompile("start", "error message %s", "formatted"), "it's starting") // a.Regexpf("start...$", "it's not starting", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) { Regexpf(a.t, rx, str, msg, args...) } // Subset asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // // a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) { Subset(a.t, list, subset, msgAndArgs...) } // Subsetf asserts that the specified list(array, slice...) contains all // elements given in the specified subset(array, slice...). // // a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) { Subsetf(a.t, list, subset, msg, args...) } // True asserts that the specified value is true. // // a.True(myBool) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) True(value bool, msgAndArgs ...interface{}) { True(a.t, value, msgAndArgs...) } // Truef asserts that the specified value is true. // // a.Truef(myBool, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) Truef(value bool, msg string, args ...interface{}) { Truef(a.t, value, msg, args...) } // WithinDuration asserts that the two times are within duration delta of each other. 
// // a.WithinDuration(time.Now(), time.Now(), 10*time.Second) // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) { WithinDuration(a.t, expected, actual, delta, msgAndArgs...) } // WithinDurationf asserts that the two times are within duration delta of each other. // // a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) { WithinDurationf(a.t, expected, actual, delta, msg, args...) } // Zero asserts that i is the zero value for its type and returns the truth. func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) { Zero(a.t, i, msgAndArgs...) } // Zerof asserts that i is the zero value for its type and returns the truth. func (a *Assertions) Zerof(i interface{}, msg string, args ...interface{}) { Zerof(a.t, i, msg, args...) 
} ================================================ FILE: vendor/github.com/stretchr/testify/require/require_forward.go.tmpl ================================================ {{.CommentWithoutT "a"}} func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) { {{.DocInfo.Name}}(a.t, {{.ForwardedParams}}) } ================================================ FILE: vendor/github.com/stretchr/testify/require/requirements.go ================================================ package require // TestingT is an interface wrapper around *testing.T type TestingT interface { Errorf(format string, args ...interface{}) FailNow() } //go:generate go run ../_codegen/main.go -output-package=require -template=require.go.tmpl -include-format-funcs ================================================ FILE: vendor/go.uber.org/dig/CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] - No changes yet. ## [1.5.0] - 2018-09-19 ### Added - Added a `DeferAcyclicVerification` container option that defers graph cycle detection until the next Invoke. ### Changed - Improved cycle-detection performance by 50x in certain degenerative cases. ## [1.4.0] - 2018-08-16 ### Added - Added `Visualize` function to visualize the state of the container in the GraphViz DOT format. This allows visualization of error types and the dependency relationships of types in the container. - Added `CanVisualizeError` function to determine if an error can be visualized in the graph. - Added `Name` option for `Provide` to add named values to the container without rewriting constructors. See package documentation for more information. ### Changed - `name:"..."` tags on nested Result Objects will now cause errors instead of being ignored. 
## [1.3.0] - 2017-12-04
### Changed
- Improved messages for errors thrown by Dig in many scenarios to be more informative.

## [1.2.0] - 2017-11-07
### Added
- `dig.In` and `dig.Out` now support value groups, making it possible to produce many values of the same type from different constructors. See package documentation for more information.

## [1.1.0] - 2017-09-15
### Added
- Added the `dig.RootCause` function which allows retrieving the original constructor error that caused an `Invoke` failure.

### Changed
- Errors from `Invoke` now attempt to hint to the user about the presence of a similar type, for example a pointer to the requested type and vice versa.

## [1.0.0] - 2017-07-31
First stable release: no breaking changes will be made in the 1.x series.

### Changed
- `Provide` and `Invoke` will now fail if `dig.In` or `dig.Out` structs contain unexported fields. Previously these fields were ignored, which often led to confusion.

## [1.0.0-rc2] - 2017-07-21
### Added
- Exported `dig.IsIn` and `dig.IsOut` so that consuming libraries can check if a params or return struct embeds the `dig.In` and `dig.Out` types, respectively.

### Changed
- Added variadic options to all public APIs so that new functionality can be introduced post v1.0.0 without introducing breaking changes.
- Functions with variadic arguments can now be passed to `dig.Provide` and `dig.Invoke`. Previously this caused an error, whereas now the args will be ignored.

## [1.0.0-rc1] - 2017-06-21
First release candidate.

## [0.5.0] - 2017-06-19
### Added
- `dig.In` and `dig.Out` now support named instances, i.e.:

  ```go
  type param struct {
    dig.In

    DB1 DB.Connection `name:"primary"`
    DB2 DB.Connection `name:"secondary"`
  }
  ```

### Fixed
- Structs compatible with `dig.In` and `dig.Out` may now be generated using `reflect.StructOf`.

## [0.4.0] - 2017-06-12
### Added
- Add `dig.In` embeddable type for advanced use-cases of specifying dependencies.
- Add `dig.Out` embeddable type for advanced use-cases of constructors inserting types in the container.
- Add support for optional parameters through `optional:"true"` tag on `dig.In` objects.
- Add support for value types and many built-ins (maps, slices, channels).

### Changed
- **[Breaking]** Restrict the API surface to only `Provide` and `Invoke`.
- **[Breaking]** Update `Provide` method to accept variadic arguments.

### Removed
- **[Breaking]** Remove `Must*` funcs to greatly reduce API surface area.
- Providing constructors with common returned types results in an error.

## [0.3] - 2017-05-02
### Added
- Add functionality to `Provide` to support constructors with `n` return objects to be resolved into the `dig.Graph`.
- Add `Invoke` function to invoke provided function and insert return objects into the `dig.Graph`.

### Changed
- Rename `RegisterAll` and `MustRegisterAll` to `ProvideAll` and `MustProvideAll`.

## [0.2] - 2017-03-27
### Changed
- Rename `Register` to `Provide` for clarity and to reduce clash with other Register functions.
- Rename `dig.Graph` to `dig.Container`.

### Removed
- Remove the package-level functions and the `DefaultGraph`.

## 0.1 - 2017-03-23
Initial release.

[Unreleased]: https://github.com/uber-go/dig/compare/v1.5.0...HEAD [1.5.0]: https://github.com/uber-go/dig/compare/v1.4.0...v1.5.0 [1.4.0]: https://github.com/uber-go/dig/compare/v1.3.0...v1.4.0 [1.3.0]: https://github.com/uber-go/dig/compare/v1.2.0...v1.3.0 [1.2.0]: https://github.com/uber-go/dig/compare/v1.1.0...v1.2.0 [1.1.0]: https://github.com/uber-go/dig/compare/v1.0.0...v1.1.0 [1.0.0]: https://github.com/uber-go/dig/compare/v1.0.0-rc2...v1.0.0 [1.0.0-rc2]: https://github.com/uber-go/dig/compare/v1.0.0-rc1...v1.0.0-rc2 [1.0.0-rc1]: https://github.com/uber-go/dig/compare/v0.5.0...v1.0.0-rc1 [0.5.0]: https://github.com/uber-go/dig/compare/v0.4.0...v0.5.0 [0.4.0]: https://github.com/uber-go/dig/compare/v0.3...v0.4.0 [0.3]: https://github.com/uber-go/dig/compare/v0.2...v0.3 [0.2]: https://github.com/uber-go/dig/compare/v0.1...v0.2 ================================================ FILE: vendor/go.uber.org/dig/LICENSE ================================================ Copyright (c) 2017-2018 Uber Technologies, Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: vendor/go.uber.org/dig/Makefile
================================================
BENCH_FLAGS ?= -cpuprofile=cpu.pprof -memprofile=mem.pprof -benchmem
PKGS ?= $(shell glide novendor | grep -v examples)
PKG_FILES ?= *.go
GO_VERSION := $(shell go version | cut -d " " -f 3)

.PHONY: all
all: lint test

# golint is published under its vanity import path golang.org/x/lint; fetching
# it via github.com/golang/lint fails the import-path check.
.PHONY: dependencies
dependencies:
	@echo "Installing Glide and locked dependencies..."
	glide --version || go get -u -f github.com/Masterminds/glide
	glide install
	@echo "Installing uber-license tool..."
	command -v update-license >/dev/null || go get -u -f go.uber.org/tools/update-license
	@echo "Installing golint..."
	command -v golint >/dev/null || go get -u -f golang.org/x/lint/golint

.PHONY: license
license: dependencies
	./check_license.sh | tee -a lint.log

# `go tool vet` was removed in Go 1.12 and `go test -i` is rejected since
# Go 1.20. `go vet` builds whatever it needs itself and takes package
# patterns directly, so both old steps collapse into a single call.
.PHONY: lint
lint:
	@rm -rf lint.log
	@echo "Checking formatting..."
	@gofmt -d -s $(PKG_FILES) 2>&1 | tee lint.log
	@echo "Checking vet..."
	@go vet $(PKGS) 2>&1 | tee -a lint.log
	@echo "Checking lint..."
	@$(foreach dir,$(PKGS),golint $(dir) 2>&1 | tee -a lint.log;)
	@echo "Checking for unresolved FIXMEs..."
	@git grep -i fixme | grep -v -e vendor -e Makefile | tee -a lint.log
	@echo "Checking for license headers..."
	@DRY_RUN=1 ./check_license.sh | tee -a lint.log
	@[ ! -s lint.log ]

.PHONY: test
test:
	@.build/test.sh

.PHONY: ci
ci: SHELL := /bin/bash
ci: test
	bash <(curl -s https://codecov.io/bash)

.PHONY: bench
BENCH ?= .
bench:
	@$(foreach pkg,$(PKGS),go test -bench=$(BENCH) -run="^$$" $(BENCH_FLAGS) $(pkg);)

================================================
FILE: vendor/go.uber.org/dig/README.md
================================================
# :hammer: dig [![GoDoc][doc-img]][doc] [![GitHub release][release-img]][release] [![Build Status][ci-img]][ci] [![Coverage Status][cov-img]][cov] [![Go Report Card][report-card-img]][report-card]

A reflection based dependency injection toolkit for Go.

### Good for:

* Powering an application framework, e.g. [Fx](https://github.com/uber-go/fx).
* Resolving the object graph during process startup.

### Bad for:

* Using in place of an application framework, e.g. [Fx](https://github.com/uber-go/fx).
* Resolving dependencies after the process has already started.
* Exposing to user-land code as a [Service Locator](https://martinfowler.com/articles/injection.html#UsingAServiceLocator).

## Installation

We recommend locking to [SemVer](http://semver.org/) range `^1` using [Glide](https://github.com/Masterminds/glide):

```
glide get 'go.uber.org/dig#^1'
```

## Stability

This library is `v1` and follows [SemVer](http://semver.org/) strictly.

No breaking changes will be made to exported APIs before `v2.0.0`.
[doc-img]: http://img.shields.io/badge/GoDoc-Reference-blue.svg [doc]: https://godoc.org/go.uber.org/dig [release-img]: https://img.shields.io/github/release/uber-go/dig.svg [release]: https://github.com/uber-go/dig/releases [ci-img]: https://img.shields.io/travis/uber-go/dig/master.svg [ci]: https://travis-ci.org/uber-go/dig/branches [cov-img]: https://codecov.io/gh/uber-go/dig/branch/master/graph/badge.svg [cov]: https://codecov.io/gh/uber-go/dig/branch/master [report-card-img]: https://goreportcard.com/badge/github.com/uber-go/dig [report-card]: https://goreportcard.com/report/github.com/uber-go/dig ================================================ FILE: vendor/go.uber.org/dig/check_license.sh ================================================ #!/bin/bash set -eo pipefail run_update_license() { # doing this because of SC2046 warning for file in $(git ls-files | grep '\.go$'); do update-license $@ "${file}" done } if [ -z "${DRY_RUN}" ]; then run_update_license else DRY_OUTPUT="$(run_update_license --dry)" if [ -n "${DRY_OUTPUT}" ]; then echo "The following files do not have correct license headers." echo "Please run make license and amend your commit." echo echo "${DRY_OUTPUT}" exit 1 fi fi ================================================ FILE: vendor/go.uber.org/dig/cycle.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig import ( "bytes" "fmt" "go.uber.org/dig/internal/digreflect" ) type cycleEntry struct { Key key Func *digreflect.Func } type errCycleDetected struct { Path []cycleEntry } func (e errCycleDetected) Error() string { // We get something like, // // foo provided by "path/to/package".NewFoo (path/to/file.go:42) // depends on bar provided by "another/package".NewBar (somefile.go:1) // depends on baz provided by "somepackage".NewBar (anotherfile.go:2) // depends on foo provided by "path/to/package".NewFoo (path/to/file.go:42) // b := new(bytes.Buffer) for i, entry := range e.Path { if i > 0 { b.WriteString("\n\tdepends on ") } fmt.Fprintf(b, "%v provided by %v", entry.Key, entry.Func) } return b.String() } // IsCycleDetected returns a boolean as to whether the provided error indicates // a cycle was detected in the container graph. 
func IsCycleDetected(err error) bool { _, ok := RootCause(err).(errCycleDetected) return ok } func verifyAcyclic(c containerStore, n provider, k key) error { visited := make(map[key]struct{}) err := detectCycles(n, c, []cycleEntry{ {Key: k, Func: n.Location()}, }, visited) if err != nil { err = errWrapf(err, "this function introduces a cycle") } return err } func detectCycles(n provider, c containerStore, path []cycleEntry, visited map[key]struct{}) error { var err error walkParam(n.ParamList(), paramVisitorFunc(func(param param) bool { if err != nil { return false } var ( k key providers []provider ) switch p := param.(type) { case paramSingle: k = key{name: p.Name, t: p.Type} if _, ok := visited[k]; ok { // We've already checked the dependencies for this type. return false } providers = c.getValueProviders(p.Name, p.Type) case paramGroupedSlice: // NOTE: The key uses the element type, not the slice type. k = key{group: p.Group, t: p.Type.Elem()} if _, ok := visited[k]; ok { // We've already checked the dependencies for this type. return false } providers = c.getGroupProviders(p.Group, p.Type.Elem()) default: // Recurse for non-edge params. return true } entry := cycleEntry{Func: n.Location(), Key: k} if len(path) > 0 { // Only mark a key as visited if path exists, i.e. this is not the // first iteration through the c.verifyAcyclic() check. Otherwise the // early exit from checking visited above will short circuit the // cycle check below. visited[k] = struct{}{} // If it exists, the first element of path is the new addition to the // graph, therefore it must be in any cycle that exists, assuming // verifyAcyclic has been run for every previous Provide. // // Alternatively, if deferAcyclicVerification was set and detectCycles // is only being called before the first Invoke, each node in the // graph will be tested as the first element of the path, so any // cycle that exists is guaranteed to trip the following condition. 
if path[0].Key == k { err = errCycleDetected{Path: append(path, entry)} return false } } for _, n := range providers { if e := detectCycles(n, c, append(path, entry), visited); e != nil { err = e return false } } return true })) return err } ================================================ FILE: vendor/go.uber.org/dig/dig.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig import ( "errors" "fmt" "io" "math/rand" "reflect" "sort" "strconv" "strings" "text/template" "time" "go.uber.org/dig/internal/digreflect" "go.uber.org/dig/internal/dot" ) const ( _optionalTag = "optional" _nameTag = "name" _groupTag = "group" ) // Unique identification of an object in the graph. type key struct { t reflect.Type // Only one of name or group will be set. name string group string } // Option configures a Container. 
It's included for future functionality; // currently, there are no concrete implementations. type Option interface { applyOption(*Container) } type optionFunc func(*Container) func (f optionFunc) applyOption(c *Container) { f(c) } type provideOptions struct { Name string } func (o *provideOptions) Validate() error { // Names must be representable inside a backquoted string. The only // limitation for raw string literals as per // https://golang.org/ref/spec#raw_string_lit is that they cannot contain // backquotes. if strings.ContainsRune(o.Name, '`') { return fmt.Errorf("invalid dig.Name(%q): names cannot contain backquotes", o.Name) } return nil } // A ProvideOption modifies the default behavior of Provide. type ProvideOption interface { applyProvideOption(*provideOptions) } type provideOptionFunc func(*provideOptions) func (f provideOptionFunc) applyProvideOption(opts *provideOptions) { f(opts) } // Name is a ProvideOption that specifies that all values produced by a // constructor should have the given name. See also the package documentation // about Named Values. // // Given, // // func NewReadOnlyConnection(...) (*Connection, error) // func NewReadWriteConnection(...) (*Connection, error) // // The following will provide two connections to the container: one under the // name "ro" and the other under the name "rw". // // c.Provide(NewReadOnlyConnection, dig.Name("ro")) // c.Provide(NewReadWriteConnection, dig.Name("rw")) // // This option cannot be provided for constructors which produce result // objects. func Name(name string) ProvideOption { return provideOptionFunc(func(opts *provideOptions) { opts.Name = name }) } // An InvokeOption modifies the default behavior of Invoke. It's included for // future functionality; currently, there are no concrete implementations. type InvokeOption interface { unimplemented() } // Container is a directed acyclic graph of types and their dependencies. 
type Container struct { // Mapping from key to all the nodes that can provide a value for that // key. providers map[key][]*node // All nodes in the container. nodes []*node // Values that have already been generated in the container. values map[key]reflect.Value // Values groups that have already been generated in the container. groups map[key][]reflect.Value // Source of randomness. rand *rand.Rand // Flag indicating whether the graph has been checked for cycles. isVerifiedAcyclic bool // Defer acyclic check on provide until Invoke. deferAcyclicVerification bool } // containerWriter provides write access to the Container's underlying data // store. type containerWriter interface { // setValue sets the value with the given name and type in the container. // If a value with the same name and type already exists, it will be // overwritten. setValue(name string, t reflect.Type, v reflect.Value) // submitGroupedValue submits a value to the value group with the provided // name. submitGroupedValue(name string, t reflect.Type, v reflect.Value) } // containerStore provides access to the Container's underlying data store. type containerStore interface { containerWriter // Returns a slice containing all known types. knownTypes() []reflect.Type // Retrieves the value with the provided name and type, if any. getValue(name string, t reflect.Type) (v reflect.Value, ok bool) // Retrieves all values for the provided group and type. // // The order in which the values are returned is undefined. getValueGroup(name string, t reflect.Type) []reflect.Value // Returns the providers that can produce a value with the given name and // type. getValueProviders(name string, t reflect.Type) []provider // Returns the providers that can produce values for the given group and // type. getGroupProviders(name string, t reflect.Type) []provider createGraph() *dot.Graph } // provider encapsulates a user-provided constructor. 
type provider interface { // ID is a unique numerical identifier for this provider. ID() dot.CtorID // Location returns where this constructor was defined. Location() *digreflect.Func // ParamList returns information about the direct dependencies of this // constructor. ParamList() paramList // ResultList returns information about the values produced by this // constructor. ResultList() resultList // Calls the underlying constructor, reading values from the // containerStore as needed. // // The values produced by this provider should be submitted into the // containerStore. Call(containerStore) error } // New constructs a Container. func New(opts ...Option) *Container { c := &Container{ providers: make(map[key][]*node), values: make(map[key]reflect.Value), groups: make(map[key][]reflect.Value), rand: rand.New(rand.NewSource(time.Now().UnixNano())), } for _, opt := range opts { opt.applyOption(c) } return c } // DeferAcyclicVerification is an Option to override the default behavior // of container.Provide, deferring the dependency graph validation to no longer // run after each call to container.Provide. The container will instead verify // the graph on first `Invoke`. // // Applications adding providers to a container in a tight loop may experience // performance improvements by initializing the container with this option. func DeferAcyclicVerification() Option { return optionFunc(func(c *Container) { c.deferAcyclicVerification = true }) } // A VisualizeOption modifies the default behavior of Visualize. type VisualizeOption interface { applyVisualizeOption(*visualizeOptions) } type visualizeOptions struct { VisualizeError error } type visualizeOptionFunc func(*visualizeOptions) func (f visualizeOptionFunc) applyVisualizeOption(opts *visualizeOptions) { f(opts) } // VisualizeError includes a visualization of the given error in the output of // Visualize if an error was returned by Invoke or Provide. 
// // if err := c.Provide(...); err != nil { // dig.Visualize(c, w, dig.VisualizeError(err)) // } // // This option has no effect if the error was nil or if it didn't contain any // information to visualize. func VisualizeError(err error) VisualizeOption { return visualizeOptionFunc(func(opts *visualizeOptions) { opts.VisualizeError = err }) } func updateGraph(dg *dot.Graph, err error) error { var errors []errVisualizer // Unwrap error to find the root cause. for { if ev, ok := err.(errVisualizer); ok { errors = append(errors, ev) } e, ok := err.(causer) if !ok { break } err = e.cause() } // If there are no errVisualizers included, we do not modify the graph. if len(errors) == 0 { return nil } // We iterate in reverse because the last element is the root cause. for i := len(errors) - 1; i >= 0; i-- { errors[i].updateGraph(dg) } return nil } var _graphTmpl = template.Must( template.New("DotGraph"). Funcs(template.FuncMap{ "quote": strconv.Quote, }). Parse(`digraph { graph [compound=true]; {{range $g := .Groups}} {{- quote .String}} [{{.Attributes}}]; {{range .Results}} {{- quote $g.String}} -> {{quote .String}}; {{end}} {{end -}} {{range $index, $ctor := .Ctors}} subgraph cluster_{{$index}} { constructor_{{$index}} [shape=plaintext label={{quote .Name}}]; {{with .ErrorType}}color={{.Color}};{{end}} {{range .Results}} {{- quote .String}} [{{.Attributes}}]; {{end}} } {{range .Params}} constructor_{{$index}} -> {{quote .String}} [ltail=cluster_{{$index}}{{if .Optional}} style=dashed{{end}}]; {{end}} {{range .GroupParams}} constructor_{{$index}} -> {{quote .String}} [ltail=cluster_{{$index}}]; {{end -}} {{end}} {{range .Failed.TransitiveFailures}} {{- quote .String}} [color=orange]; {{end -}} {{range .Failed.RootCauses}} {{- quote .String}} [color=red]; {{end}} }`)) // Visualize parses the graph in Container c into DOT format and writes it to // io.Writer w. 
func Visualize(c *Container, w io.Writer, opts ...VisualizeOption) error { dg := c.createGraph() var options visualizeOptions for _, o := range opts { o.applyVisualizeOption(&options) } if options.VisualizeError != nil { if err := updateGraph(dg, options.VisualizeError); err != nil { return err } } return _graphTmpl.Execute(w, dg) } // CanVisualizeError returns true if the error is an errVisualizer. func CanVisualizeError(err error) bool { for { if _, ok := err.(errVisualizer); ok { return true } e, ok := err.(causer) if !ok { break } err = e.cause() } return false } func (c *Container) createGraph() *dot.Graph { dg := dot.NewGraph() for _, n := range c.nodes { dg.AddCtor(newDotCtor(n), n.paramList.DotParam(), n.resultList.DotResult()) } return dg } // Changes the source of randomness for the container. // // This will help provide determinism during tests. func setRand(r *rand.Rand) Option { return optionFunc(func(c *Container) { c.rand = r }) } func (c *Container) knownTypes() []reflect.Type { typeSet := make(map[reflect.Type]struct{}, len(c.providers)) for k := range c.providers { typeSet[k.t] = struct{}{} } types := make([]reflect.Type, 0, len(typeSet)) for t := range typeSet { types = append(types, t) } sort.Sort(byTypeName(types)) return types } func (c *Container) getValue(name string, t reflect.Type) (v reflect.Value, ok bool) { v, ok = c.values[key{name: name, t: t}] return } func (c *Container) setValue(name string, t reflect.Type, v reflect.Value) { c.values[key{name: name, t: t}] = v } func (c *Container) getValueGroup(name string, t reflect.Type) []reflect.Value { items := c.groups[key{group: name, t: t}] // shuffle the list so users don't rely on the ordering of grouped values return shuffledCopy(c.rand, items) } func (c *Container) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { k := key{group: name, t: t} c.groups[k] = append(c.groups[k], v) } func (c *Container) getValueProviders(name string, t reflect.Type) []provider { return 
c.getProviders(key{name: name, t: t}) } func (c *Container) getGroupProviders(name string, t reflect.Type) []provider { return c.getProviders(key{group: name, t: t}) } func (c *Container) getProviders(k key) []provider { nodes := c.providers[k] providers := make([]provider, len(nodes)) for i, n := range nodes { providers[i] = n } return providers } // Provide teaches the container how to build values of one or more types and // expresses their dependencies. // // The first argument of Provide is a function that accepts zero or more // parameters and returns one or more results. The function may optionally // return an error to indicate that it failed to build the value. This // function will be treated as the constructor for all the types it returns. // This function will be called AT MOST ONCE when a type produced by it, or a // type that consumes this function's output, is requested via Invoke. If the // same types are requested multiple times, the previously produced value will // be reused. // // In addition to accepting constructors that accept dependencies as separate // arguments and produce results as separate return values, Provide also // accepts constructors that specify dependencies as dig.In structs and/or // specify results as dig.Out structs. func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) error { ctype := reflect.TypeOf(constructor) if ctype == nil { return errors.New("can't provide an untyped nil") } if ctype.Kind() != reflect.Func { return fmt.Errorf("must provide constructor function, got %v (type %v)", constructor, ctype) } var options provideOptions for _, o := range opts { o.applyProvideOption(&options) } if err := options.Validate(); err != nil { return err } if err := c.provide(constructor, options); err != nil { return errProvide{ Func: digreflect.InspectFunc(constructor), Reason: err, } } return nil } // Invoke runs the given function after instantiating its dependencies. 
// // Any arguments that the function has are treated as its dependencies. The // dependencies are instantiated in an unspecified order along with any // dependencies that they might have. // // The function may return an error to indicate failure. The error will be // returned to the caller as-is. func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { ftype := reflect.TypeOf(function) if ftype == nil { return errors.New("can't invoke an untyped nil") } if ftype.Kind() != reflect.Func { return fmt.Errorf("can't invoke non-function %v (type %v)", function, ftype) } pl, err := newParamList(ftype) if err != nil { return err } if err := shallowCheckDependencies(c, pl); err != nil { return errMissingDependencies{ Func: digreflect.InspectFunc(function), Reason: err, } } if !c.isVerifiedAcyclic { if err := c.verifyAcyclic(); err != nil { return err } } args, err := pl.BuildList(c) if err != nil { return errArgumentsFailed{ Func: digreflect.InspectFunc(function), Reason: err, } } returned := reflect.ValueOf(function).Call(args) if len(returned) == 0 { return nil } if last := returned[len(returned)-1]; isError(last.Type()) { if err, _ := last.Interface().(error); err != nil { return err } } return nil } func (c *Container) verifyAcyclic() error { visited := make(map[key]struct{}) for _, n := range c.nodes { if err := detectCycles(n, c, nil /* path */, visited); err != nil { return errWrapf(err, "cycle detected in dependency graph") } } c.isVerifiedAcyclic = true return nil } func (c *Container) provide(ctor interface{}, opts provideOptions) error { n, err := newNode(ctor, nodeOptions{ResultName: opts.Name}) if err != nil { return err } keys, err := c.findAndValidateResults(n) if err != nil { return err } ctype := reflect.TypeOf(ctor) if len(keys) == 0 { return fmt.Errorf("%v must provide at least one non-error type", ctype) } for k := range keys { c.isVerifiedAcyclic = false oldProviders := c.providers[k] c.providers[k] = append(c.providers[k], n) if 
c.deferAcyclicVerification { continue } if err := verifyAcyclic(c, n, k); err != nil { c.providers[k] = oldProviders return err } c.isVerifiedAcyclic = true } c.nodes = append(c.nodes, n) return nil } // Builds a collection of all result types produced by this node. func (c *Container) findAndValidateResults(n *node) (map[key]struct{}, error) { var err error keyPaths := make(map[key]string) walkResult(n.ResultList(), connectionVisitor{ c: c, n: n, err: &err, keyPaths: keyPaths, }) if err != nil { return nil, err } keys := make(map[key]struct{}, len(keyPaths)) for k := range keyPaths { keys[k] = struct{}{} } return keys, nil } // Visits the results of a node and compiles a collection of all the keys // produced by that node. type connectionVisitor struct { c *Container n *node // If this points to a non-nil value, we've already encountered an error // and should stop traversing. err *error // Map of keys provided to path that provided this. The path is a string // documenting which positional return value or dig.Out attribute is // providing this particular key. // // For example, "[0].Foo" indicates that the value was provided by the Foo // attribute of the dig.Out returned as the first result of the // constructor. keyPaths map[key]string // We track the path to the current result here. For example, this will // be, ["[1]", "Foo", "Bar"] when we're visiting Bar in, // // func() (io.Writer, struct { // dig.Out // // Foo struct { // dig.Out // // Bar io.Reader // } // }) currentResultPath []string } func (cv connectionVisitor) AnnotateWithField(f resultObjectField) resultVisitor { cv.currentResultPath = append(cv.currentResultPath, f.FieldName) return cv } func (cv connectionVisitor) AnnotateWithPosition(i int) resultVisitor { cv.currentResultPath = append(cv.currentResultPath, fmt.Sprintf("[%d]", i)) return cv } func (cv connectionVisitor) Visit(res result) resultVisitor { // Already failed. Stop looking. 
if *cv.err != nil { return nil } path := strings.Join(cv.currentResultPath, ".") switch r := res.(type) { case resultSingle: k := key{name: r.Name, t: r.Type} if conflict, ok := cv.keyPaths[k]; ok { *cv.err = fmt.Errorf( "cannot provide %v from %v: already provided by %v", k, path, conflict) return nil } if ps := cv.c.providers[k]; len(ps) > 0 { cons := make([]string, len(ps)) for i, p := range ps { cons[i] = fmt.Sprint(p.Location()) } *cv.err = fmt.Errorf( "cannot provide %v from %v: already provided by %v", k, path, strings.Join(cons, "; ")) return nil } cv.keyPaths[k] = path case resultGrouped: // we don't really care about the path for this since conflicts are // okay for group results. We'll track it for the sake of having a // value there. k := key{group: r.Group, t: r.Type} cv.keyPaths[k] = path } return cv } // node is a node in the dependency graph. Each node maps to a single // constructor provided by the user. // // Nodes can produce zero or more values that they store into the container. // For the Provide path, we verify that nodes produce at least one value, // otherwise the function will never be called. type node struct { ctor interface{} ctype reflect.Type // Location where this function was defined. location *digreflect.Func // id uniquely identifies the constructor that produces a node. id dot.CtorID // Whether the constructor owned by this node was already called. called bool // Type information about constructor parameters. paramList paramList // Type information about constructor results. resultList resultList } type nodeOptions struct { // If specified, all values produced by this node have the provided name. 
ResultName string
}

// newNode wraps the constructor ctor in a graph node, computing its parameter
// and result lists from the constructor's reflected type. An error is returned
// if either list cannot be built.
func newNode(ctor interface{}, opts nodeOptions) (*node, error) {
	cval := reflect.ValueOf(ctor)
	ctype := cval.Type()
	cptr := cval.Pointer()

	params, err := newParamList(ctype)
	if err != nil {
		return nil, err
	}

	results, err := newResultList(ctype, resultOptions{Name: opts.ResultName})
	if err != nil {
		return nil, err
	}

	// Both lists were built successfully; there is no error to propagate.
	return &node{
		ctor:       ctor,
		ctype:      ctype,
		location:   digreflect.InspectFunc(ctor),
		id:         dot.CtorID(cptr),
		paramList:  params,
		resultList: results,
	}, nil
}

// Location returns the source location of the node's constructor.
func (n *node) Location() *digreflect.Func { return n.location }

// ParamList returns the parameters consumed by the node's constructor.
func (n *node) ParamList() paramList { return n.paramList }

// ResultList returns the results produced by the node's constructor.
func (n *node) ResultList() resultList { return n.resultList }

// ID returns the DOT-graph identifier of the node's constructor.
func (n *node) ID() dot.CtorID { return n.id }

// Call calls this node's constructor if it hasn't already been called and
// injects any values produced by it into the provided container.
//
// Results are staged first and committed to c only when extraction succeeds,
// so a failing constructor leaves the container untouched.
func (n *node) Call(c containerStore) error {
	if n.called {
		return nil
	}

	if err := shallowCheckDependencies(c, n.paramList); err != nil {
		return errMissingDependencies{
			Func:   n.location,
			Reason: err,
		}
	}

	args, err := n.paramList.BuildList(c)
	if err != nil {
		return errArgumentsFailed{
			Func:   n.location,
			Reason: err,
		}
	}

	receiver := newStagingContainerWriter()
	results := reflect.ValueOf(n.ctor).Call(args)
	if err := n.resultList.ExtractList(receiver, results); err != nil {
		return errConstructorFailed{Func: n.location, Reason: err}
	}
	receiver.Commit(c)
	n.called = true

	return nil
}

// isFieldOptional reports whether a field of an In struct is optional, based
// on its `optional` tag. A missing tag means "not optional"; a tag that is not
// a valid boolean yields an error.
func isFieldOptional(f reflect.StructField) (bool, error) {
	tag := f.Tag.Get(_optionalTag)
	if tag == "" {
		return false, nil
	}

	optional, err := strconv.ParseBool(tag)
	if err != nil {
		err = errWrapf(err,
			"invalid value %q for %q tag on field %v",
			tag, _optionalTag, f.Name)
	}

	return optional, err
}

// shallowCheckDependencies checks that all direct dependencies of the provided
// param are present in the container. It returns an errMissingManyTypes
// listing every non-optional single param without a value provider, or nil if
// none are missing.
func shallowCheckDependencies(c containerStore, p param) error {
	var missing errMissingManyTypes
	walkParam(p, paramVisitorFunc(func(child param) bool {
		ps, ok := child.(paramSingle)
		if !ok {
			return true
		}

		if ns := c.getValueProviders(ps.Name, ps.Type); len(ns) == 0 && !ps.Optional {
			missing = append(missing, newErrMissingType(c, key{name: ps.Name, t: ps.Type}))
		}

		return true
	}))

	if len(missing) > 0 {
		return missing
	}
	return nil
}

// stagingContainerWriter is a containerWriter that records the changes that
// would be made to a containerWriter and defers them until Commit is called.
type stagingContainerWriter struct {
	values map[key]reflect.Value
	groups map[key][]reflect.Value
}

var _ containerWriter = (*stagingContainerWriter)(nil)

// newStagingContainerWriter returns an empty stagingContainerWriter.
func newStagingContainerWriter() *stagingContainerWriter {
	return &stagingContainerWriter{
		values: make(map[key]reflect.Value),
		groups: make(map[key][]reflect.Value),
	}
}

// setValue records a named value to be committed later.
func (sr *stagingContainerWriter) setValue(name string, t reflect.Type, v reflect.Value) {
	sr.values[key{t: t, name: name}] = v
}

// submitGroupedValue records a value-group member to be committed later.
func (sr *stagingContainerWriter) submitGroupedValue(group string, t reflect.Type, v reflect.Value) {
	k := key{t: t, group: group}
	sr.groups[k] = append(sr.groups[k], v)
}

// Commit commits the received results to the provided containerWriter.
// Because the staged entries live in maps, the order in which they are
// written to cw is unspecified.
func (sr *stagingContainerWriter) Commit(cw containerWriter) {
	for k, v := range sr.values {
		cw.setValue(k.name, k.t, v)
	}

	for k, vs := range sr.groups {
		for _, v := range vs {
			cw.submitGroupedValue(k.group, k.t, v)
		}
	}
}

// byTypeName sorts reflect.Types by their fmt.Sprint representation.
type byTypeName []reflect.Type

func (bs byTypeName) Len() int {
	return len(bs)
}

func (bs byTypeName) Less(i int, j int) bool {
	return fmt.Sprint(bs[i]) < fmt.Sprint(bs[j])
}

func (bs byTypeName) Swap(i int, j int) {
	bs[i], bs[j] = bs[j], bs[i]
}

// shuffledCopy returns a new slice holding the elements of items in an order
// chosen by rand.Perm; items itself is not modified.
func shuffledCopy(rand *rand.Rand, items []reflect.Value) []reflect.Value {
	newItems := make([]reflect.Value, len(items))
	for i, j := range rand.Perm(len(items)) {
		newItems[i] = items[j]
	}
	return newItems
}

// newDotCtor builds the DOT-graph representation of node n's constructor.
func newDotCtor(n *node) *dot.Ctor {
	return &dot.Ctor{
		ID:      n.id,
		Name:    n.location.Name,
		Package: n.location.Package,
		File:    n.location.File,
		Line:    n.location.Line,
	}
}

================================================
FILE: vendor/go.uber.org/dig/doc.go
================================================
// Copyright (c) 2018 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // Package dig provides an opinionated way of resolving object dependencies. // // Status // // STABLE. No breaking changes will be made in this major version. // // Container // // Dig exposes type Container as an object capable of resolving a directed // acyclic dependency graph. Use the New function to create one. // // c := dig.New() // // Provide // // Constructors for different types are added to the container by using the // Provide method. A constructor can declare a dependency on another type by // simply adding it as a function parameter. Dependencies for a type can be // added to the graph both, before and after the type was added. // // err := c.Provide(func(conn *sql.DB) (*UserGateway, error) { // // ... // }) // if err != nil { // // ... // } // // if err := c.Provide(newDBConnection); err != nil { // // ... // } // // Multiple constructors can rely on the same type. The container creates a // singleton for each retained type, instantiating it at most once when // requested directly or as a dependency of another type. // // err := c.Provide(func(conn *sql.DB) *CommentGateway { // // ... // }) // if err != nil { // // ... // } // // Constructors can declare any number of dependencies as parameters and // optionally, return errors. // // err := c.Provide(func(u *UserGateway, c *CommentGateway) (*RequestHandler, error) { // // ... // }) // if err != nil { // // ... // } // // if err := c.Provide(newHTTPServer); err != nil { // // ... // } // // Constructors can also return multiple results to add multiple types to the // container. // // err := c.Provide(func(conn *sql.DB) (*UserGateway, *CommentGateway, error) { // // ... // }) // if err != nil { // // ... 
// } // // Constructors that accept a variadic number of arguments are treated as if // they don't have those arguments. That is, // // func NewVoteGateway(db *sql.DB, options ...Option) *VoteGateway // // Is treated the same as, // // func NewVoteGateway(db *sql.DB) *VoteGateway // // The constructor will be called with all other dependencies and no variadic // arguments. // // Invoke // // Types added to the container may be consumed by using the Invoke method. // Invoke accepts any function that accepts one or more parameters and // optionally returns an error. Dig calls the function with the requested // type, instantiating only those types that were requested by the function. // The call fails if any type or its dependencies (both direct and transitive) // are not available in the container. // // err := c.Invoke(func(l *log.Logger) { // // ... // }) // if err != nil { // // ... // } // // err := c.Invoke(func(server *http.Server) error { // // ... // }) // if err != nil { // // ... // } // // Any error returned by the invoked function is propagated back to the // caller. // // Parameter Objects // // Constructors declare their dependencies as function parameters. This can // very quickly become unreadable if the constructor has a lot of // dependencies. // // func NewHandler(users *UserGateway, comments *CommentGateway, posts *PostGateway, votes *VoteGateway, authz *AuthZGateway) *Handler { // // ... // } // // A pattern employed to improve readability in a situation like this is to // create a struct that lists all the parameters of the function as fields and // change the function to accept that struct instead. This is referred to as // a parameter object. // // Dig has first-class support for parameter objects: any struct embedding // dig.In gets treated as a parameter object. The following is equivalent to // the constructor above.
// // type HandlerParams struct { // dig.In // // Users *UserGateway // Comments *CommentGateway // Posts *PostGateway // Votes *VoteGateway // AuthZ *AuthZGateway // } // // func NewHandler(p HandlerParams) *Handler { // // ... // } // // Handlers can receive any combination of parameter objects and parameters. // // func NewHandler(p HandlerParams, l *log.Logger) *Handler { // // ... // } // // Result Objects // // Result objects are the flip side of parameter objects. These are structs // that represent multiple outputs from a single function as fields in the // struct. Structs embedding dig.Out get treated as result objects. // // func SetupGateways(conn *sql.DB) (*UserGateway, *CommentGateway, *PostGateway, error) { // // ... // } // // The above is equivalent to, // // type Gateways struct { // dig.Out // // Users *UserGateway // Comments *CommentGateway // Posts *PostGateway // } // // func SetupGateways(conn *sql.DB) (Gateways, error) { // // ... // } // // Optional Dependencies // // Constructors often don't have a hard dependency on some types and // are able to operate in a degraded state when that dependency is missing. // Dig supports declaring dependencies as optional by adding an // `optional:"true"` tag to fields of a dig.In struct. // // Fields in a dig.In struct that have the `optional:"true"` tag are treated // as optional by Dig. // // type UserGatewayParams struct { // dig.In // // Conn *sql.DB // Cache *redis.Client `optional:"true"` // } // // If an optional field is not available in the container, the constructor // will receive a zero value for the field. // // func NewUserGateway(p UserGatewayParams, log *log.Logger) (*UserGateway, error) { // if p.Cache == nil { // log.Print("Caching disabled") // } // // ... // } // // Constructors that declare dependencies as optional MUST handle the case of // those dependencies being absent.
// // The optional tag also allows adding new dependencies without breaking // existing consumers of the constructor. // // Named Values // // Some use cases call for multiple values of the same type. Dig allows adding // multiple values of the same type to the container with the use of Named // Values. // // Named Values can be produced by passing the dig.Name option when a // constructor is provided. All values produced by that constructor will have // the given name. // // Given the following constructors, // // func NewReadOnlyConnection(...) (*sql.DB, error) // func NewReadWriteConnection(...) (*sql.DB, error) // // You can provide *sql.DB into a Container under different names by passing // the dig.Name option. // // c.Provide(NewReadOnlyConnection, dig.Name("ro")) // c.Provide(NewReadWriteConnection, dig.Name("rw")) // // Alternatively, you can produce a dig.Out struct and tag its fields with // `name:".."` to have the corresponding value added to the graph under the // specified name. // // type ConnectionResult struct { // dig.Out // // ReadWrite *sql.DB `name:"rw"` // ReadOnly *sql.DB `name:"ro"` // } // // func ConnectToDatabase(...) (ConnectionResult, error) { // // ... // return ConnectionResult{ReadWrite: rw, ReadOnly: ro}, nil // } // // Regardless of how a Named Value was produced, it can be consumed by another // constructor by accepting a dig.In struct which has exported fields with the // same name AND type that you provided. // // type GatewayParams struct { // dig.In // // WriteToConn *sql.DB `name:"rw"` // ReadFromConn *sql.DB `name:"ro"` // } // // The name tag may be combined with the optional tag to declare the // dependency optional. 
// // type GatewayParams struct { // dig.In // // WriteToConn *sql.DB `name:"rw"` // ReadFromConn *sql.DB `name:"ro" optional:"true"` // } // // func NewCommentGateway(p GatewayParams, log *log.Logger) (*CommentGateway, error) { // if p.ReadFromConn == nil { // log.Print("Warning: Using RW connection for reads") // p.ReadFromConn = p.WriteToConn // } // // ... // } // // Value Groups // // Added in Dig 1.2. // // Dig provides value groups to allow producing and consuming many values of // the same type. Value groups allow constructors to send values to a named, // unordered collection in the container. Other constructors can request all // values in this collection as a slice. // // Constructors can send values into value groups by returning a dig.Out // struct tagged with `group:".."`. // // type HandlerResult struct { // dig.Out // // Handler Handler `group:"server"` // } // // func NewHelloHandler() HandlerResult { // .. // } // // func NewEchoHandler() HandlerResult { // .. // } // // Any number of constructors may provide values to this named collection. // Other constructors can request all values for this collection by requesting // a slice tagged with `group:".."`. This will execute all constructors that // provide a value to that group in an unspecified order. // // type ServerParams struct { // dig.In // // Handlers []Handler `group:"server"` // } // // func NewServer(p ServerParams) *Server { // server := newServer() // for _, h := range p.Handlers { // server.Register(h) // } // return server // } // // Note that values in a value group are unordered. Dig makes no guarantees // about the order in which these values will be produced. package dig // import "go.uber.org/dig" ================================================ FILE: vendor/go.uber.org/dig/error.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig import ( "bytes" "fmt" "reflect" "sort" "go.uber.org/dig/internal/digreflect" "go.uber.org/dig/internal/dot" ) // Errors which know their underlying cause should implement this interface to // be compatible with RootCause. // // We use an unexported "cause" method instead of "Cause" because we don't // want dig-internal causes to be confused with the cause of the user-provided // errors. (For example, if the users are using github.com/pkg/errors.) type causer interface { cause() error } // RootCause returns the original error that caused the provided dig failure. // // RootCause may be used on errors returned by Invoke to get the original // error returned by a constructor or invoked function. func RootCause(err error) error { for { if e, ok := err.(causer); ok { err = e.cause() } else { return err } } } // errWrapf wraps an existing error with more contextual information. 
// // The given error is treated as the cause of the returned error (see causer). // // RootCause(errWrapf(errWrapf(err, ...), ...)) == err // // Use errWrapf instead of fmt.Errorf if the message ends with ": ". func errWrapf(err error, msg string, args ...interface{}) error { if err == nil { return nil } if len(args) > 0 { msg = fmt.Sprintf(msg, args...) } return wrappedError{err: err, msg: msg} } type wrappedError struct { err error msg string } func (e wrappedError) cause() error { return e.err } func (e wrappedError) Error() string { return fmt.Sprintf("%v: %v", e.msg, e.err) } // errProvide is returned when a constructor could not be Provided into the // container. type errProvide struct { Func *digreflect.Func Reason error } func (e errProvide) cause() error { return e.Reason } func (e errProvide) Error() string { return fmt.Sprintf("function %v cannot be provided: %v", e.Func, e.Reason) } // errConstructorFailed is returned when a user-provided constructor failed // with a non-nil error. type errConstructorFailed struct { Func *digreflect.Func Reason error } func (e errConstructorFailed) cause() error { return e.Reason } func (e errConstructorFailed) Error() string { return fmt.Sprintf("function %v returned a non-nil error: %v", e.Func, e.Reason) } // errArgumentsFailed is returned when a function could not be run because one // of its dependencies failed to build for any reason. type errArgumentsFailed struct { Func *digreflect.Func Reason error } func (e errArgumentsFailed) cause() error { return e.Reason } func (e errArgumentsFailed) Error() string { return fmt.Sprintf("could not build arguments for function %v: %v", e.Func, e.Reason) } // errMissingDependencies is returned when the dependencies of a function are // not available in the container. 
type errMissingDependencies struct { Func *digreflect.Func Reason error } func (e errMissingDependencies) cause() error { return e.Reason } func (e errMissingDependencies) Error() string { return fmt.Sprintf("missing dependencies for function %v: %v", e.Func, e.Reason) } // errParamSingleFailed is returned when a paramSingle could not be built. type errParamSingleFailed struct { Key key Reason error CtorID dot.CtorID } func (e errParamSingleFailed) cause() error { return e.Reason } func (e errParamSingleFailed) Error() string { return fmt.Sprintf("failed to build %v: %v", e.Key, e.Reason) } func (e errParamSingleFailed) updateGraph(g *dot.Graph) { failed := &dot.Result{ Node: &dot.Node{ Name: e.Key.name, Group: e.Key.group, Type: e.Key.t, }, } g.FailNodes([]*dot.Result{failed}, e.CtorID) } // errParamGroupFailed is returned when a value group cannot be built because // any of the values in the group failed to build. type errParamGroupFailed struct { Key key Reason error CtorID dot.CtorID } func (e errParamGroupFailed) cause() error { return e.Reason } func (e errParamGroupFailed) Error() string { return fmt.Sprintf("could not build value group %v: %v", e.Key, e.Reason) } func (e errParamGroupFailed) updateGraph(g *dot.Graph) { g.FailGroupNodes(e.Key.group, e.Key.t, e.CtorID) } // errMissingType is returned when a single value that was expected in the // container was not available. type errMissingType struct { Key key // If non-empty, we will include suggestions for what the user may have // meant. suggestions []key } func newErrMissingType(c containerStore, k key) errMissingType { // Possible types we will look for in the container. We will always look // for pointers to the requested type and some extras on a per-Kind basis. suggestions := []reflect.Type{reflect.PtrTo(k.t)} if k.t.Kind() == reflect.Ptr { // The user requested a pointer but maybe we have a value. 
suggestions = append(suggestions, k.t.Elem()) } knownTypes := c.knownTypes() if k.t.Kind() == reflect.Interface { // Maybe we have an implementation of the interface. for _, t := range knownTypes { if t.Implements(k.t) { suggestions = append(suggestions, t) } } } else { // Maybe we have an interface that this type implements. for _, t := range knownTypes { if t.Kind() == reflect.Interface { if k.t.Implements(t) { suggestions = append(suggestions, t) } } } } // range through c.providers is non-deterministic. Let's sort the list of // suggestions. sort.Sort(byTypeName(suggestions)) err := errMissingType{Key: k} for _, t := range suggestions { if len(c.getValueProviders(k.name, t)) > 0 { k.t = t err.suggestions = append(err.suggestions, k) } } return err } func (e errMissingType) Error() string { // Sample messages: // // type io.Reader is not in the container, did you mean to Provide it? // type io.Reader is not in the container, did you mean to use one of *bytes.Buffer, *MyBuffer // type bytes.Buffer is not in the container, did you mean to use *bytes.Buffer? // type *foo[name="bar"] is not in the container, did you mean to use foo[name="bar"]? b := new(bytes.Buffer) fmt.Fprintf(b, "type %v is not in the container", e.Key) switch len(e.suggestions) { case 0: b.WriteString(", did you mean to Provide it?") case 1: fmt.Fprintf(b, ", did you mean to use %v?", e.suggestions[0]) default: b.WriteString(", did you mean to use one of ") for i, k := range e.suggestions { if i > 0 { b.WriteString(", ") if i == len(e.suggestions)-1 { b.WriteString("or ") } } fmt.Fprint(b, k) } b.WriteString("?") } return b.String() } // errMissingManyTypes combines multiple errMissingType errors. 
type errMissingManyTypes []errMissingType // length must be non-zero func (e errMissingManyTypes) Error() string { if len(e) == 1 { return e[0].Error() } b := new(bytes.Buffer) b.WriteString("the following types are not in the container: ") for i, err := range e { if i > 0 { b.WriteString("; ") } fmt.Fprintf(b, "%v", err.Key) switch len(err.suggestions) { case 0: // do nothing case 1: fmt.Fprintf(b, " (did you mean %v?)", err.suggestions[0]) default: b.WriteString(" (did you mean ") for i, k := range err.suggestions { if i > 0 { b.WriteString(", ") if i == len(err.suggestions)-1 { b.WriteString("or ") } } fmt.Fprint(b, k) } b.WriteString("?)") } } return b.String() } func (e errMissingManyTypes) updateGraph(g *dot.Graph) { missing := make([]*dot.Result, len(e)) for i, err := range e { missing[i] = &dot.Result{ Node: &dot.Node{ Name: err.Key.name, Group: err.Key.group, Type: err.Key.t, }, } } g.AddMissingNodes(missing) } type errVisualizer interface { updateGraph(*dot.Graph) } ================================================ FILE: vendor/go.uber.org/dig/glide.yaml ================================================ package: go.uber.org/dig license: MIT testImport: - package: github.com/stretchr/testify subpackages: - assert - require ================================================ FILE: vendor/go.uber.org/dig/internal/digreflect/func.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package digreflect import ( "fmt" "net/url" "reflect" "runtime" "strings" ) // Func contains runtime information about a function. type Func struct { // Name of the function. Name string // Name of the package in which this function is defined. Package string // Path to the file in which this function is defined. File string // Line number in the file at which this function is defined. Line int } // String returns a string representation of the function. func (f *Func) String() string { // "path/to/package".MyFunction (path/to/file.go:42) return fmt.Sprintf("%q.%v (%v:%v)", f.Package, f.Name, f.File, f.Line) } // InspectFunc inspects and returns runtime information about the given // function. 
func InspectFunc(function interface{}) *Func { fptr := reflect.ValueOf(function).Pointer() f := runtime.FuncForPC(fptr) pkgName, funcName := splitFuncName(f.Name()) fileName, lineNum := f.FileLine(fptr) return &Func{ Name: funcName, Package: pkgName, File: fileName, Line: lineNum, } } const _vendor = "/vendor/" func splitFuncName(function string) (pname string, fname string) { if len(function) == 0 { return } // We have something like "path.to/my/pkg.MyFunction". If the function is // a closure, it is something like, "path.to/my/pkg.MyFunction.func1". idx := 0 // Everything up to the first "." after the last "/" is the package name. // Everything after the "." is the full function name. if i := strings.LastIndex(function, "/"); i >= 0 { idx = i } if i := strings.Index(function[idx:], "."); i >= 0 { idx += i } pname, fname = function[:idx], function[idx+1:] // The package may be vendored. if i := strings.Index(pname, _vendor); i > 0 { pname = pname[i+len(_vendor):] } // Package names are URL-encoded to avoid ambiguity in the case where the // package name contains ".git". Otherwise, "foo/bar.git.MyFunction" would // mean that "git" is the top-level function and "MyFunction" is embedded // inside it. if unescaped, err := url.QueryUnescape(pname); err == nil { pname = unescaped } return } ================================================ FILE: vendor/go.uber.org/dig/internal/dot/graph.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dot import ( "fmt" "reflect" ) // ErrorType of a constructor or group is updated when they fail to build. type ErrorType int const ( noError ErrorType = iota rootCause transitiveFailure ) // CtorID is a unique numeric identifier for constructors. type CtorID uintptr // Ctor encodes a constructor provided to the container for the DOT graph. type Ctor struct { Name string Package string File string Line int ID CtorID Params []*Param GroupParams []*Group Results []*Result ErrorType ErrorType } // Node is a single node in a graph and is embedded into Params and Results. type Node struct { Type reflect.Type Name string Group string } // Param is a parameter node in the graph. type Param struct { *Node Optional bool } // Result is a result node in the graph. type Result struct { *Node // GroupIndex is added to differentiate grouped values from one another. 
// Since grouped values have the same type and group, their Node / string // representations are the same so we need indices to uniquely identify // the values. GroupIndex int } // Group is a group node in the graph. type Group struct { // Type is the type of values in the group. Type reflect.Type Name string Results []*Result ErrorType ErrorType } // Graph is the DOT-format graph in a Container. type Graph struct { Ctors []*Ctor ctorMap map[CtorID]*Ctor Groups []*Group groupMap map[groupKey]*Group Failed *FailedNodes } // FailedNodes is the nodes that failed in the graph. type FailedNodes struct { // RootCauses is a list of the point of failures. They are the root causes // of failed invokes and can be either missing types (not provided) or // error types (error providing). RootCauses []*Result // TransitiveFailures is the list of nodes that failed to build due to // missing/failed dependencies. TransitiveFailures []*Result } type groupKey struct { t reflect.Type group string } // NewGraph creates an empty graph. func NewGraph() *Graph { return &Graph{ ctorMap: make(map[CtorID]*Ctor), groupMap: make(map[groupKey]*Group), Failed: &FailedNodes{}, } } // NewGroup creates a new group with information in the groupKey. func NewGroup(k groupKey) *Group { return &Group{ Type: k.t, Name: k.group, } } // AddCtor adds the constructor with paramList and resultList into the graph. func (dg *Graph) AddCtor(c *Ctor, paramList []*Param, resultList []*Result) { var ( params []*Param groupParams []*Group ) // Loop through the paramList to separate them into regular params and // grouped params. For grouped params, we use getGroup to find the actual // group. for _, param := range paramList { if param.Group == "" { // Not a value group. 
params = append(params, param) continue } k := groupKey{t: param.Type.Elem(), group: param.Group} group := dg.getGroup(k) groupParams = append(groupParams, group) } for _, result := range resultList { // If the result is a grouped value, we want to update its GroupIndex // and add it to the Group. if result.Group != "" { dg.addToGroup(result, c.ID) } } c.Params = params c.GroupParams = groupParams c.Results = resultList dg.Ctors = append(dg.Ctors, c) dg.ctorMap[c.ID] = c } func (dg *Graph) failNode(r *Result, isRootCause bool) { if isRootCause { dg.addRootCause(r) } else { dg.addTransitiveFailure(r) } } // AddMissingNodes adds missing nodes to the list of failed Results in the graph. func (dg *Graph) AddMissingNodes(results []*Result) { // The failure(s) are root causes if there are no other failures. isRootCause := len(dg.Failed.RootCauses) == 0 for _, r := range results { dg.failNode(r, isRootCause) } } // FailNodes adds results to the list of failed Results in the graph, and // updates the state of the constructor with the given id accordingly. func (dg *Graph) FailNodes(results []*Result, id CtorID) { // This failure is the root cause if there are no other failures. isRootCause := len(dg.Failed.RootCauses) == 0 for _, r := range results { dg.failNode(r, isRootCause) } if c, ok := dg.ctorMap[id]; ok { if isRootCause { c.ErrorType = rootCause } else { c.ErrorType = transitiveFailure } } } // FailGroupNodes finds and adds the failed grouped nodes to the list of failed // Results in the graph, and updates the state of the group and constructor // with the given id accordingly. func (dg *Graph) FailGroupNodes(name string, t reflect.Type, id CtorID) { // This failure is the root cause if there are no other failures. 
isRootCause := len(dg.Failed.RootCauses) == 0 k := groupKey{t: t, group: name} group := dg.getGroup(k) for _, r := range dg.ctorMap[id].Results { if r.Type == t && r.Group == name { dg.failNode(r, isRootCause) } } if c, ok := dg.ctorMap[id]; ok { if isRootCause { group.ErrorType = rootCause c.ErrorType = rootCause } else { group.ErrorType = transitiveFailure c.ErrorType = transitiveFailure } } } // getGroup finds the group by groupKey from the graph. If it is not available, // a new group is created and returned. func (dg *Graph) getGroup(k groupKey) *Group { g, ok := dg.groupMap[k] if !ok { g = NewGroup(k) dg.groupMap[k] = g dg.Groups = append(dg.Groups, g) } return g } // addToGroup adds a newly provided grouped result to the appropriate group. func (dg *Graph) addToGroup(r *Result, id CtorID) { k := groupKey{t: r.Type, group: r.Group} group := dg.getGroup(k) r.GroupIndex = len(group.Results) group.Results = append(group.Results, r) } // String implements fmt.Stringer for Param. func (p *Param) String() string { if p.Name != "" { return fmt.Sprintf("%v[name=%v]", p.Type.String(), p.Name) } return p.Type.String() } // String implements fmt.Stringer for Result. func (r *Result) String() string { switch { case r.Name != "": return fmt.Sprintf("%v[name=%v]", r.Type.String(), r.Name) case r.Group != "": return fmt.Sprintf("%v[group=%v]%v", r.Type.String(), r.Group, r.GroupIndex) default: return r.Type.String() } } // String implements fmt.Stringer for Group. func (g *Group) String() string { return fmt.Sprintf("[type=%v group=%v]", g.Type.String(), g.Name) } // Attributes composes and returns a string of the Result node's attributes. func (r *Result) Attributes() string { switch { case r.Name != "": return fmt.Sprintf(`label=<%v
Name: %v>`, r.Type, r.Name) case r.Group != "": return fmt.Sprintf(`label=<%v
Group: %v>`, r.Type, r.Group) default: return fmt.Sprintf(`label=<%v>`, r.Type) } } // Attributes composes and returns a string of the Group node's attributes. func (g *Group) Attributes() string { attr := fmt.Sprintf(`shape=diamond label=<%v
Group: %v>`, g.Type, g.Name) if g.ErrorType != noError { attr += " color=" + g.ErrorType.Color() } return attr } // Color returns the color representation of each ErrorType. func (s ErrorType) Color() string { switch s { case rootCause: return "red" case transitiveFailure: return "orange" default: return "black" } } func (dg *Graph) addRootCause(r *Result) { dg.Failed.RootCauses = append(dg.Failed.RootCauses, r) } func (dg *Graph) addTransitiveFailure(r *Result) { dg.Failed.TransitiveFailures = append(dg.Failed.TransitiveFailures, r) } ================================================ FILE: vendor/go.uber.org/dig/param.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig import ( "errors" "fmt" "reflect" "go.uber.org/dig/internal/dot" ) // The param interface represents a dependency for a constructor. 
// // The following implementations exist: // paramList All arguments of the constructor. // paramSingle An explicitly requested type. // paramObject dig.In struct where each field in the struct can be another // param. // paramGroupedSlice // A slice consuming a value group. This will receive all // values produced with a `group:".."` tag with the same name // as a slice. type param interface { fmt.Stringer // Builds this dependency and any of its dependencies from the provided // Container. // // This MAY panic if the param does not produce a single value. Build(containerStore) (reflect.Value, error) // DotParam returns a slice of dot.Param(s). DotParam() []*dot.Param } var ( _ param = paramSingle{} _ param = paramObject{} _ param = paramList{} _ param = paramGroupedSlice{} ) // newParam builds a param from the given type. If the provided type is a // dig.In struct, an paramObject will be returned. func newParam(t reflect.Type) (param, error) { switch { case IsOut(t) || (t.Kind() == reflect.Ptr && IsOut(t.Elem())) || embedsType(t, _outPtrType): return nil, fmt.Errorf("cannot depend on result objects: %v embeds a dig.Out", t) case IsIn(t): return newParamObject(t) case embedsType(t, _inPtrType): return nil, fmt.Errorf( "cannot build a parameter object by embedding *dig.In, embed dig.In instead: "+ "%v embeds *dig.In", t) case t.Kind() == reflect.Ptr && IsIn(t.Elem()): return nil, fmt.Errorf( "cannot depend on a pointer to a parameter object, use a value instead: "+ "%v is a pointer to a struct that embeds dig.In", t) default: return paramSingle{Type: t}, nil } } // paramVisitor visits every param in a param tree, allowing tracking state at // each level. type paramVisitor interface { // Visit is called on the param being visited. // // If Visit returns a non-nil paramVisitor, that paramVisitor visits all // the child params of this param. 
Visit(param) paramVisitor // We can implement AnnotateWithField and AnnotateWithPosition like // resultVisitor if we need to track that information in the future. } // paramVisitorFunc is a paramVisitor that visits param in a tree with the // return value deciding whether the descendants of this param should be // recursed into. type paramVisitorFunc func(param) (recurse bool) func (f paramVisitorFunc) Visit(p param) paramVisitor { if f(p) { return f } return nil } // walkParam walks the param tree for the given param with the provided // visitor. // // paramVisitor.Visit will be called on the provided param and if a non-nil // paramVisitor is received, this param's descendants will be walked with that // visitor. // // This is very similar to how go/ast.Walk works. func walkParam(p param, v paramVisitor) { v = v.Visit(p) if v == nil { return } switch par := p.(type) { case paramSingle, paramGroupedSlice: // No sub-results case paramObject: for _, f := range par.Fields { walkParam(f.Param, v) } case paramList: for _, p := range par.Params { walkParam(p, v) } default: panic(fmt.Sprintf( "It looks like you have found a bug in dig. "+ "Please file an issue at https://github.com/uber-go/dig/issues/ "+ "and provide the following message: "+ "received unknown param type %T", p)) } } // paramList holds all arguments of the constructor as params. // // NOTE: Build() MUST NOT be called on paramList. Instead, BuildList // must be called. type paramList struct { ctype reflect.Type // type of the constructor Params []param } func (pl paramList) DotParam() []*dot.Param { var types []*dot.Param for _, param := range pl.Params { types = append(types, param.DotParam()...) } return types } // newParamList builds a paramList from the provided constructor type. // // Variadic arguments of a constructor are ignored and not included as // dependencies. 
func newParamList(ctype reflect.Type) (paramList, error) { numArgs := ctype.NumIn() if ctype.IsVariadic() { // NOTE: If the function is variadic, we skip the last argument // because we're not filling variadic arguments yet. See #120. numArgs-- } pl := paramList{ ctype: ctype, Params: make([]param, 0, numArgs), } for i := 0; i < numArgs; i++ { p, err := newParam(ctype.In(i)) if err != nil { return pl, errWrapf(err, "bad argument %d", i+1) } pl.Params = append(pl.Params, p) } return pl, nil } func (pl paramList) Build(containerStore) (reflect.Value, error) { panic("It looks like you have found a bug in dig. " + "Please file an issue at https://github.com/uber-go/dig/issues/ " + "and provide the following message: " + "paramList.Build() must never be called") } // BuildList returns an ordered list of values which may be passed directly // to the underlying constructor. func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { args := make([]reflect.Value, len(pl.Params)) for i, p := range pl.Params { var err error args[i], err = p.Build(c) if err != nil { return nil, err } } return args, nil } // paramSingle is an explicitly requested type, optionally with a name. // // This object must be present in the graph as-is unless it's specified as // optional. 
type paramSingle struct { Name string Optional bool Type reflect.Type } func (ps paramSingle) DotParam() []*dot.Param { return []*dot.Param{ { Node: &dot.Node{ Type: ps.Type, Name: ps.Name, }, Optional: ps.Optional, }, } } func (ps paramSingle) Build(c containerStore) (reflect.Value, error) { if v, ok := c.getValue(ps.Name, ps.Type); ok { return v, nil } providers := c.getValueProviders(ps.Name, ps.Type) if len(providers) == 0 { if ps.Optional { return reflect.Zero(ps.Type), nil } return _noValue, newErrMissingType(c, key{name: ps.Name, t: ps.Type}) } for _, n := range providers { err := n.Call(c) if err == nil { continue } // If we're missing dependencies but the parameter itself is optional, // we can just move on. if _, ok := err.(errMissingDependencies); ok && ps.Optional { return reflect.Zero(ps.Type), nil } return _noValue, errParamSingleFailed{ CtorID: n.ID(), Key: key{t: ps.Type, name: ps.Name}, Reason: err, } } // If we get here, it's impossible for the value to be absent from the // container. v, _ := c.getValue(ps.Name, ps.Type) return v, nil } // paramObject is a dig.In struct where each field is another param. // // This object is not expected in the graph as-is. type paramObject struct { Type reflect.Type Fields []paramObjectField } func (po paramObject) DotParam() []*dot.Param { var types []*dot.Param for _, field := range po.Fields { types = append(types, field.DotParam()...) } return types } // newParamObject builds an paramObject from the provided type. The type MUST // be a dig.In struct. func newParamObject(t reflect.Type) (paramObject, error) { po := paramObject{Type: t} for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Type == _inType { // Skip over the dig.In embed. 
continue } pof, err := newParamObjectField(i, f) if err != nil { return po, errWrapf(err, "bad field %q of %v", f.Name, t) } po.Fields = append(po.Fields, pof) } return po, nil } func (po paramObject) Build(c containerStore) (reflect.Value, error) { dest := reflect.New(po.Type).Elem() for _, f := range po.Fields { v, err := f.Build(c) if err != nil { return dest, err } dest.Field(f.FieldIndex).Set(v) } return dest, nil } // paramObjectField is a single field of a dig.In struct. type paramObjectField struct { // Name of the field in the struct. FieldName string // Index of this field in the target struct. // // We need to track this separately because not all fields of the // struct map to params. FieldIndex int // The dependency requested by this field. Param param } func (pof paramObjectField) DotParam() []*dot.Param { return pof.Param.DotParam() } func newParamObjectField(idx int, f reflect.StructField) (paramObjectField, error) { pof := paramObjectField{ FieldName: f.Name, FieldIndex: idx, } var p param switch { case f.PkgPath != "": return pof, fmt.Errorf( "unexported fields not allowed in dig.In, did you mean to export %q (%v)?", f.Name, f.Type) case f.Tag.Get(_groupTag) != "": var err error p, err = newParamGroupedSlice(f) if err != nil { return pof, err } default: var err error p, err = newParam(f.Type) if err != nil { return pof, err } } if ps, ok := p.(paramSingle); ok { ps.Name = f.Tag.Get(_nameTag) var err error ps.Optional, err = isFieldOptional(f) if err != nil { return pof, err } p = ps } pof.Param = p return pof, nil } func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { v, err := pof.Param.Build(c) if err != nil { return v, err } return v, nil } // paramGroupedSlice is a param which produces a slice of values with the same // group name. type paramGroupedSlice struct { // Name of the group as specified in the `group:".."` tag. Group string // Type of the slice. 
Type reflect.Type } func (pt paramGroupedSlice) DotParam() []*dot.Param { return []*dot.Param{ { Node: &dot.Node{ Type: pt.Type, Group: pt.Group, }, }, } } // newParamGroupedSlice builds a paramGroupedSlice from the provided type with // the given name. // // The type MUST be a slice type. func newParamGroupedSlice(f reflect.StructField) (paramGroupedSlice, error) { pg := paramGroupedSlice{Group: f.Tag.Get(_groupTag), Type: f.Type} name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { case f.Type.Kind() != reflect.Slice: return pg, fmt.Errorf("value groups may be consumed as slices only: "+ "field %q (%v) is not a slice", f.Name, f.Type) case name != "": return pg, fmt.Errorf( "cannot use named values with value groups: name:%q requested with group:%q", name, pg.Group) case optional: return pg, errors.New("value groups cannot be optional") } return pg, nil } func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { for _, n := range c.getGroupProviders(pt.Group, pt.Type.Elem()) { if err := n.Call(c); err != nil { return _noValue, errParamGroupFailed{ CtorID: n.ID(), Key: key{group: pt.Group, t: pt.Type.Elem()}, Reason: err, } } } items := c.getValueGroup(pt.Group, pt.Type.Elem()) result := reflect.MakeSlice(pt.Type, len(items), len(items)) for i, v := range items { result.Index(i).Set(v) } return result, nil } ================================================ FILE: vendor/go.uber.org/dig/result.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig import ( "errors" "fmt" "reflect" "go.uber.org/dig/internal/dot" ) // The result interface represents a result produced by a constructor. // // The following implementations exist: // resultList All values returned by the constructor. // resultSingle A single value produced by a constructor. // resultObject dig.Out struct where each field in the struct can be // another result. // resultGrouped A value produced by a constructor that is part of a value // group. type result interface { // Extracts the values for this result from the provided value and // stores them into the provided containerWriter. // // This MAY panic if the result does not consume a single value. Extract(containerWriter, reflect.Value) // DotResult returns a slice of dot.Result(s). 
DotResult() []*dot.Result } var ( _ result = resultSingle{} _ result = resultObject{} _ result = resultList{} _ result = resultGrouped{} ) type resultOptions struct { // If set, this is the name of the associated result value. // // For Result Objects, name:".." tags on fields override this. Name string } // newResult builds a result from the given type. func newResult(t reflect.Type, opts resultOptions) (result, error) { switch { case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType): return nil, fmt.Errorf("cannot provide parameter objects: %v embeds a dig.In", t) case isError(t): return nil, fmt.Errorf("cannot return an error here, return it from the constructor instead") case IsOut(t): return newResultObject(t, opts) case embedsType(t, _outPtrType): return nil, fmt.Errorf( "cannot build a result object by embedding *dig.Out, embed dig.Out instead: "+ "%v embeds *dig.Out", t) case t.Kind() == reflect.Ptr && IsOut(t.Elem()): return nil, fmt.Errorf( "cannot return a pointer to a result object, use a value instead: "+ "%v is a pointer to a struct that embeds dig.Out", t) default: return resultSingle{Type: t, Name: opts.Name}, nil } } // resultVisitor visits every result in a result tree, allowing tracking state // at each level. type resultVisitor interface { // Visit is called on the result being visited. // // If Visit returns a non-nil resultVisitor, that resultVisitor visits all // the child results of this result. Visit(result) resultVisitor // AnnotateWithField is called on each field of a resultObject after // visiting it but before walking its descendants. // // The same resultVisitor is used for all fields: the one returned upon // visiting the resultObject. // // For each visited field, if AnnotateWithField returns a non-nil // resultVisitor, it will be used to walk the result of that field. 
AnnotateWithField(resultObjectField) resultVisitor // AnnotateWithPosition is called with the index of each result of a // resultList after vising it but before walking its descendants. // // The same resultVisitor is used for all results: the one returned upon // visiting the resultList. // // For each position, if AnnotateWithPosition returns a non-nil // resultVisitor, it will be used to walk the result at that index. AnnotateWithPosition(idx int) resultVisitor } // walkResult walks the result tree for the given result with the provided // visitor. // // resultVisitor.Visit will be called on the provided result and if a non-nil // resultVisitor is received, it will be used to walk its descendants. If a // resultObject or resultList was visited, AnnotateWithField and // AnnotateWithPosition respectively will be called before visiting the // descendants of that resultObject/resultList. // // This is very similar to how go/ast.Walk works. func walkResult(r result, v resultVisitor) { v = v.Visit(r) if v == nil { return } switch res := r.(type) { case resultSingle, resultGrouped: // No sub-results case resultObject: w := v for _, f := range res.Fields { if v := w.AnnotateWithField(f); v != nil { walkResult(f.Result, v) } } case resultList: w := v for i, r := range res.Results { if v := w.AnnotateWithPosition(i); v != nil { walkResult(r, v) } } default: panic(fmt.Sprintf( "It looks like you have found a bug in dig. "+ "Please file an issue at https://github.com/uber-go/dig/issues/ "+ "and provide the following message: "+ "received unknown result type %T", res)) } } // resultList holds all values returned by the constructor as results. type resultList struct { ctype reflect.Type Results []result // For each item at index i returned by the constructor, resultIndexes[i] // is the index in .Results for the corresponding result object. // resultIndexes[i] is -1 for errors returned by constructors. 
resultIndexes []int } func (rl resultList) DotResult() []*dot.Result { var types []*dot.Result for _, result := range rl.Results { types = append(types, result.DotResult()...) } return types } func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { rl := resultList{ ctype: ctype, Results: make([]result, 0, ctype.NumOut()), resultIndexes: make([]int, ctype.NumOut()), } resultIdx := 0 for i := 0; i < ctype.NumOut(); i++ { t := ctype.Out(i) if isError(t) { rl.resultIndexes[i] = -1 continue } r, err := newResult(t, opts) if err != nil { return rl, errWrapf(err, "bad result %d", i+1) } rl.Results = append(rl.Results, r) rl.resultIndexes[i] = resultIdx resultIdx++ } return rl, nil } func (resultList) Extract(containerWriter, reflect.Value) { panic("It looks like you have found a bug in dig. " + "Please file an issue at https://github.com/uber-go/dig/issues/ " + "and provide the following message: " + "resultList.Extract() must never be called") } func (rl resultList) ExtractList(cw containerWriter, values []reflect.Value) error { for i, v := range values { if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { rl.Results[resultIdx].Extract(cw, v) continue } if err, _ := v.Interface().(error); err != nil { return err } } return nil } // resultSingle is an explicit value produced by a constructor, optionally // with a name. // // This object will be added to the graph as-is. type resultSingle struct { Name string Type reflect.Type } func (rs resultSingle) DotResult() []*dot.Result { return []*dot.Result{ { Node: &dot.Node{ Type: rs.Type, Name: rs.Name, }, }, } } func (rs resultSingle) Extract(cw containerWriter, v reflect.Value) { cw.setValue(rs.Name, rs.Type, v) } // resultObject is a dig.Out struct where each field is another result. // // This object is not added to the graph. Its fields are interpreted as // results and added to the graph if needed. 
type resultObject struct { Type reflect.Type Fields []resultObjectField } func (ro resultObject) DotResult() []*dot.Result { var types []*dot.Result for _, field := range ro.Fields { types = append(types, field.DotResult()...) } return types } func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { ro := resultObject{Type: t} if len(opts.Name) > 0 { return ro, fmt.Errorf( "cannot specify a name for result objects: %v embeds dig.Out", t) } for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Type == _outType { // Skip over the dig.Out embed. continue } rof, err := newResultObjectField(i, f, opts) if err != nil { return ro, errWrapf(err, "bad field %q of %v", f.Name, t) } ro.Fields = append(ro.Fields, rof) } return ro, nil } func (ro resultObject) Extract(cw containerWriter, v reflect.Value) { for _, f := range ro.Fields { f.Result.Extract(cw, v.Field(f.FieldIndex)) } } // resultObjectField is a single field inside a dig.Out struct. type resultObjectField struct { // Name of the field in the struct. FieldName string // Index of the field in the struct. // // We need to track this separately because not all fields of the struct // map to results. FieldIndex int // Result produced by this field. Result result } func (rof resultObjectField) DotResult() []*dot.Result { return rof.Result.DotResult() } // newResultObjectField(i, f, opts) builds a resultObjectField from the field // f at index i. 
func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (resultObjectField, error) { rof := resultObjectField{ FieldName: f.Name, FieldIndex: idx, } var r result switch { case f.PkgPath != "": return rof, fmt.Errorf( "unexported fields not allowed in dig.Out, did you mean to export %q (%v)?", f.Name, f.Type) case f.Tag.Get(_groupTag) != "": var err error r, err = newResultGrouped(f) if err != nil { return rof, err } default: var err error if name := f.Tag.Get(_nameTag); len(name) > 0 { // can modify in-place because options are passed-by-value. opts.Name = name } r, err = newResult(f.Type, opts) if err != nil { return rof, err } } rof.Result = r return rof, nil } // resultGrouped is a value produced by a constructor that is part of a result // group. // // These will be produced as fields of a dig.Out struct. type resultGrouped struct { // Name of the group as specified in the `group:".."` tag. Group string // Type of value produced. Type reflect.Type } func (rt resultGrouped) DotResult() []*dot.Result { return []*dot.Result{ { Node: &dot.Node{ Type: rt.Type, Group: rt.Group, }, }, } } // newResultGrouped(f) builds a new resultGrouped from the provided field. func newResultGrouped(f reflect.StructField) (resultGrouped, error) { rg := resultGrouped{Group: f.Tag.Get(_groupTag), Type: f.Type} name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) switch { case name != "": return rg, fmt.Errorf( "cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group) case optional: return rg, errors.New("value groups cannot be optional") } return rg, nil } func (rt resultGrouped) Extract(cw containerWriter, v reflect.Value) { cw.submitGroupedValue(rt.Group, rt.Type, v) } ================================================ FILE: vendor/go.uber.org/dig/stringer.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package dig import ( "bytes" "fmt" "strings" ) // String representation of the entire Container func (c *Container) String() string { b := &bytes.Buffer{} fmt.Fprintln(b, "nodes: {") for k, vs := range c.providers { for _, v := range vs { fmt.Fprintln(b, "\t", k, "->", v) } } fmt.Fprintln(b, "}") fmt.Fprintln(b, "values: {") for k, v := range c.values { fmt.Fprintln(b, "\t", k, "=>", v) } for k, vs := range c.groups { for _, v := range vs { fmt.Fprintln(b, "\t", k, "=>", v) } } fmt.Fprintln(b, "}") return b.String() } func (n *node) String() string { return fmt.Sprintf("deps: %v, ctor: %v", n.paramList, n.ctype) } func (k key) String() string { if k.name != "" { return fmt.Sprintf("%v[name=%q]", k.t, k.name) } if k.group != "" { return fmt.Sprintf("%v[group=%q]", k.t, k.group) } return k.t.String() } func (pl paramList) String() string { args := make([]string, len(pl.Params)) for i, p := range pl.Params { args[i] = p.String() } return fmt.Sprint(args) } func (sp paramSingle) String() string { // tally.Scope[optional] means optional // tally.Scope[optional, name="foo"] means named optional var opts []string if sp.Optional { opts = append(opts, "optional") } if sp.Name != "" { opts = append(opts, fmt.Sprintf("name=%q", sp.Name)) } if len(opts) == 0 { return fmt.Sprint(sp.Type) } return fmt.Sprintf("%v[%v]", sp.Type, strings.Join(opts, ", ")) } func (op paramObject) String() string { fields := make([]string, len(op.Fields)) for i, f := range op.Fields { fields[i] = f.Param.String() } return strings.Join(fields, " ") } func (pt paramGroupedSlice) String() string { // io.Reader[group="foo"] refers to a group of io.Readers called 'foo' return fmt.Sprintf("%v[group=%q]", pt.Type.Elem(), pt.Group) } ================================================ FILE: vendor/go.uber.org/dig/types.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig import ( "container/list" "reflect" ) var ( _noValue reflect.Value _errType = reflect.TypeOf((*error)(nil)).Elem() _inPtrType = reflect.TypeOf((*In)(nil)) _inType = reflect.TypeOf(In{}) _outPtrType = reflect.TypeOf((*Out)(nil)) _outType = reflect.TypeOf(Out{}) ) // Special interface embedded inside dig sentinel values (dig.In, dig.Out) to // make their special nature obvious in the godocs. Otherwise they will appear // as plain empty structs. type digSentinel interface { digSentinel() } // In may be embedded into structs to request dig to treat them as special // parameter structs. When a constructor accepts such a struct, instead of the // struct becoming a dependency for that constructor, all its fields become // dependencies instead. See the section on Parameter Objects in the // package-level documentation for more information. 
// // Fields of the struct may optionally be tagged to customize the behavior of // dig. The following tags are supported, // // name Requests a value with the same name and type from the // container. See Named Values for more information. // optional If set to true, indicates that the dependency is optional and // the constructor gracefully handles its absence. // group Name of the Value Group from which this field will be filled. // The field must be a slice type. See Value Groups in the // package documentation for more information. type In struct{ digSentinel } // Out is an embeddable type that signals to dig that the returned // struct should be treated differently. Instead of the struct itself // becoming part of the container, all members of the struct will. // Out may be embedded into structs to request dig to treat them as special // result structs. When a constructor returns such a struct, instead of the // struct becoming a result of the constructor, all its fields become results // of the constructor. See the section on Result Objects in the package-level // documentation for more information. // // Fields of the struct may optionally be tagged to customize the behavior of // dig. The following tags are supported, // // name Specifies the name of the value. Only a field on a dig.In // struct with the same 'name' annotation can receive this // value. See Named Values for more information. // group Name of the Value Group to which this field's value is being // sent. See Value Groups in the package documentation for more // information. type Out struct{ digSentinel } func isError(t reflect.Type) bool { return t.Implements(_errType) } // IsIn checks whether the given struct is a dig.In struct. A struct qualifies // as a dig.In struct if it embeds the dig.In type or if any struct that it // embeds is a dig.In struct. The parameter may be the reflect.Type of the // struct rather than the struct itself. 
// // A struct MUST qualify as a dig.In struct for its fields to be treated // specially by dig. // // See the documentation for dig.In for a comprehensive list of supported // tags. func IsIn(o interface{}) bool { return embedsType(o, _inType) } // IsOut checks whether the given struct is a dig.Out struct. A struct // qualifies as a dig.Out struct if it embeds the dig.Out type or if any // struct that it embeds is a dig.Out struct. The parameter may be the // reflect.Type of the struct rather than the struct itself. // // A struct MUST qualify as a dig.Out struct for its fields to be treated // specially by dig. // // See the documentation for dig.Out for a comprehensive list of supported // tags. func IsOut(o interface{}) bool { return embedsType(o, _outType) } // Returns true if t embeds e or if any of the types embedded by t embed e. func embedsType(i interface{}, e reflect.Type) bool { // TODO: this function doesn't consider e being a pointer. // given `type A foo { *In }`, this function would return false for // embedding dig.In, which makes for some extra error checking in places // that call this funciton. Might be worthwhile to consider reflect.Indirect // usage to clean up the callers. if i == nil { return false } // maybe it's already a reflect.Type t, ok := i.(reflect.Type) if !ok { // take the type if it's not t = reflect.TypeOf(i) } // We are going to do a breadth-first search of all embedded fields. types := list.New() types.PushBack(t) for types.Len() > 0 { t := types.Remove(types.Front()).(reflect.Type) if t == e { return true } if t.Kind() != reflect.Struct { continue } for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Anonymous { types.PushBack(f.Type) } } } // If perf is an issue, we can cache known In objects and Out objects in a // map[reflect.Type]struct{}. 
return false } ================================================ FILE: vendor/go.uber.org/dig/version.go ================================================ // Copyright (c) 2018 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package dig // Version of the library const Version = "1.6.0" ================================================ FILE: vendor/google.golang.org/appengine/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: vendor/google.golang.org/appengine/cloudsql/cloudsql.go ================================================ // Copyright 2013 Google Inc. All rights reserved. // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. /* Package cloudsql exposes access to Google Cloud SQL databases. This package does not work in App Engine "flexible environment". This package is intended for MySQL drivers to make App Engine-specific connections. Applications should use this package through database/sql: Select a pure Go MySQL driver that supports this package, and use sql.Open with protocol "cloudsql" and an address of the Cloud SQL instance. 
A Go MySQL driver that has been tested to work well with Cloud SQL
is the go-sql-driver:
	import "database/sql"
	import _ "github.com/go-sql-driver/mysql"

	db, err := sql.Open("mysql", "user@cloudsql(project-id:instance-name)/dbname")


Another driver that works well with Cloud SQL is the mymysql driver:
	import "database/sql"
	import _ "github.com/ziutek/mymysql/godrv"

	db, err := sql.Open("mymysql", "cloudsql:instance-name*dbname/user/password")


Using either of these drivers, you can perform a standard SQL query.
This example assumes there is a table named 'users' with
columns 'first_name' and 'last_name':

	rows, err := db.Query("SELECT first_name, last_name FROM users")
	if err != nil {
		log.Errorf(ctx, "db.Query: %v", err)
		return
	}
	defer rows.Close()

	for rows.Next() {
		var firstName string
		var lastName string
		if err := rows.Scan(&firstName, &lastName); err != nil {
			log.Errorf(ctx, "rows.Scan: %v", err)
			continue
		}
		log.Infof(ctx, "First: %v - Last: %v", firstName, lastName)
	}
	if err := rows.Err(); err != nil {
		log.Errorf(ctx, "Row error: %v", err)
	}
*/
package cloudsql

import (
	"net"
)

// Dial connects to the named Cloud SQL instance.
//
// It delegates to connect, whose implementation is chosen by build
// constraints: with the appengine build tag it dials through the App Engine
// appengine/cloudsql package; without it, connect always returns an error.
func Dial(instance string) (net.Conn, error) {
	return connect(instance)
}

================================================
FILE: vendor/google.golang.org/appengine/cloudsql/cloudsql_classic.go
================================================
// Copyright 2013 Google Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

// +build appengine

package cloudsql

import (
	"net"

	"appengine/cloudsql"
)

// connect dials the named Cloud SQL instance by forwarding to the App Engine
// runtime's appengine/cloudsql.Dial. This variant is compiled only when the
// appengine build tag is set.
func connect(instance string) (net.Conn, error) {
	return cloudsql.Dial(instance)
}

================================================
FILE: vendor/google.golang.org/appengine/cloudsql/cloudsql_vm.go
================================================
// Copyright 2013 Google Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

// +build !appengine

package cloudsql

import (
	"errors"
	"net"
)

// connect is the fallback used when the appengine build tag is not set.
// Dialing Cloud SQL through this package is unsupported in that
// configuration, so it always returns a nil connection and an error.
func connect(instance string) (net.Conn, error) {
	return nil, errors.New(`cloudsql: not supported in App Engine "flexible environment"`)
}

================================================
FILE: vendor/vendor.json
================================================
{
	"comment": "",
	"ignore": "test",
	"package": [
		{
			"path": "appengine/cloudsql",
			"revision": ""
		},
		{
			"checksumSHA1": "AIF4lP/6rhhHA4zsdvzfM1uFXi4=",
			"path": "github.com/DATA-DOG/go-sqlmock",
			"revision": "b9ca56ce96879f5362120ae10866bbf66f2c5db6",
			"revisionTime": "2018-03-04T15:30:57Z"
		},
		{
			"checksumSHA1": "OFu4xJEIjiI8Suu+j/gabfp+y6Q=",
			"origin": "github.com/stretchr/testify/vendor/github.com/davecgh/go-spew/spew",
			"path": "github.com/davecgh/go-spew/spew",
			"revision": "2aa2c176b9dab406a6970f6a55f513e8a8c8b18f",
			"revisionTime": "2017-08-14T20:04:35Z"
		},
		{
			"checksumSHA1": "JXVlDIoOSmyi1QAgDI455Wsr/gY=",
			"path": "github.com/go-sql-driver/mysql",
			"revision": "3287d94d4c6a48a63e16fffaabf27ab20203af2a",
			"revisionTime": "2018-04-13T18:15:57Z"
		},
		{
			"checksumSHA1": "g/V4qrXjUGG9B+e3hB+4NAYJ5Gs=",
			"path": "github.com/gorilla/context",
			"revision": "08b5f424b9271eedf6f9f0ce86cb9396ed337a42",
			"revisionTime": "2016-08-17T18:46:32Z"
		},
		{
			"checksumSHA1": "YuYKzn2jczaM6DQcFDmukvAHUX4=",
			"path": "github.com/gorilla/mux",
			"revision": "94231ffd98496cbcb1c15b7bf2a9edfd5f852cd4",
			"revisionTime": "2018-04-03T18:23:30Z"
		},
		{
			"checksumSHA1": "zKKp5SZ3d3ycKe4EKMNT0BqAWBw=",
			"origin": "github.com/stretchr/testify/vendor/github.com/pmezard/go-difflib/difflib",
			"path": "github.com/pmezard/go-difflib/difflib",
			"revision": "2aa2c176b9dab406a6970f6a55f513e8a8c8b18f",
			"revisionTime": "2017-08-14T20:04:35Z"
		},
		{
			"checksumSHA1": "EO+jcRet/AJ6IY3lBO8l8BLsZWg=",
			"origin": "github.com/stretchr/testify/vendor/github.com/stretchr/objx",
			"path": "github.com/stretchr/objx",
"revision": "2aa2c176b9dab406a6970f6a55f513e8a8c8b18f", "revisionTime": "2017-08-14T20:04:35Z" }, { "checksumSHA1": "mGbTYZ8dHVTiPTTJu3ktp+84pPI=", "path": "github.com/stretchr/testify/assert", "revision": "2aa2c176b9dab406a6970f6a55f513e8a8c8b18f", "revisionTime": "2017-08-14T20:04:35Z" }, { "checksumSHA1": "hs0IfAV4wNExbAXc0aUU9V2SuFc=", "path": "github.com/stretchr/testify/mock", "revision": "2aa2c176b9dab406a6970f6a55f513e8a8c8b18f", "revisionTime": "2017-08-14T20:04:35Z" }, { "checksumSHA1": "7vs6dSc1PPGBKyzb/SCIyeMJPLQ=", "path": "github.com/stretchr/testify/require", "revision": "2aa2c176b9dab406a6970f6a55f513e8a8c8b18f", "revisionTime": "2017-08-14T20:04:35Z" }, { "checksumSHA1": "4nm6kgOL/4Xj3j+CwKcggRg8Wno=", "path": "go.uber.org/dig", "revision": "007ab720a796b1027b6c289fb4e836c8fc077357", "revisionTime": "2018-09-19T20:28:59Z" }, { "checksumSHA1": "ukkiijCfrA/L+yTvOhmUH+5sWyI=", "path": "go.uber.org/dig/internal/digreflect", "revision": "007ab720a796b1027b6c289fb4e836c8fc077357", "revisionTime": "2018-09-19T20:28:59Z" }, { "checksumSHA1": "WURKr1FAB8025bzGumOQ/4U0rxI=", "path": "go.uber.org/dig/internal/dot", "revision": "007ab720a796b1027b6c289fb4e836c8fc077357", "revisionTime": "2018-09-19T20:28:59Z" }, { "checksumSHA1": "LiyXfqOzaeQ8vgYZH3t2hUEdVTw=", "path": "google.golang.org/appengine/cloudsql", "revision": "0a24098c0ec68416ec050f567f75df563d6b231e", "revisionTime": "2018-04-05T22:03:34Z" } ], "rootPath": "github.com/PacktPublishing/Hands-On-Dependency-Injection-in-Go" }