Repository: uber/tchannel-go Branch: dev Commit: 8f6af1a4cf92 Files: 401 Total size: 2.0 MB Directory structure: gitextract_ixnw_gpx/ ├── .github/ │ └── workflows/ │ └── tests.yaml ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── Makefile ├── README.md ├── RELEASE.md ├── all_channels.go ├── all_channels_test.go ├── arguments.go ├── arguments_test.go ├── benchmark/ │ ├── benchclient/ │ │ └── main.go │ ├── benchserver/ │ │ └── main.go │ ├── build_manager.go │ ├── client_server_bench_test.go │ ├── external_client.go │ ├── external_common.go │ ├── external_server.go │ ├── frame_templates.go │ ├── interfaces.go │ ├── internal_client.go │ ├── internal_multi_client.go │ ├── internal_server.go │ ├── internal_tcp_client.go │ ├── internal_tcp_server.go │ ├── matrix_test.go │ ├── options.go │ ├── real_relay.go │ ├── req_bytes.go │ ├── tcp_bench_test.go │ ├── tcp_frame_relay.go │ └── tcp_raw_relay.go ├── calloptions.go ├── calloptions_test.go ├── channel.go ├── channel_test.go ├── channel_utils_test.go ├── channelstate_string.go ├── checked_frame_pool.go ├── checked_frame_pool_test.go ├── checksum.go ├── close_test.go ├── codecov.yml ├── conn_leak_test.go ├── connection.go ├── connection_bench_test.go ├── connection_direction.go ├── connection_internal_test.go ├── connection_test.go ├── connectionstate_string.go ├── context.go ├── context_builder.go ├── context_header.go ├── context_internal_test.go ├── context_test.go ├── deps_test.go ├── dial_16.go ├── dial_17.go ├── dial_17_test.go ├── doc.go ├── errors.go ├── errors_test.go ├── examples/ │ ├── bench/ │ │ ├── client/ │ │ │ └── client.go │ │ ├── runner.go │ │ └── server/ │ │ └── server.go │ ├── hyperbahn/ │ │ └── echo-server/ │ │ └── main.go │ ├── hypercat/ │ │ └── main.go │ ├── keyvalue/ │ │ ├── README.md │ │ ├── client/ │ │ │ └── client.go │ │ ├── gen-go/ │ │ │ └── keyvalue/ │ │ │ ├── admin.go │ │ │ ├── baseservice.go │ │ │ ├── constants.go │ │ │ ├── keyvalue.go │ │ │ ├── 
tchan-keyvalue.go │ │ │ └── ttypes.go │ │ ├── keyvalue.thrift │ │ └── server/ │ │ └── server.go │ ├── ping/ │ │ ├── README.md │ │ └── main.go │ ├── test_server/ │ │ └── server.go │ └── thrift/ │ ├── example.thrift │ ├── gen-go/ │ │ └── example/ │ │ ├── base.go │ │ ├── constants.go │ │ ├── first.go │ │ ├── second.go │ │ ├── tchan-example.go │ │ └── ttypes.go │ └── main.go ├── fragmentation_test.go ├── fragmenting_reader.go ├── fragmenting_writer.go ├── frame.go ├── frame_pool.go ├── frame_pool_b_test.go ├── frame_pool_test.go ├── frame_test.go ├── frame_utils_test.go ├── go.mod ├── go.sum ├── guide/ │ └── Thrift_Hyperbahn.md ├── handlers.go ├── handlers_test.go ├── handlers_with_skip_test.go ├── health.go ├── health_ext_test.go ├── health_test.go ├── http/ │ ├── buf.go │ ├── buf_test.go │ ├── http_test.go │ ├── request.go │ └── response.go ├── hyperbahn/ │ ├── advertise.go │ ├── advertise_test.go │ ├── call.go │ ├── client.go │ ├── client_test.go │ ├── configuration.go │ ├── discover.go │ ├── discover_test.go │ ├── event_string.go │ ├── events.go │ ├── gen-go/ │ │ └── hyperbahn/ │ │ ├── constants.go │ │ ├── hyperbahn.go │ │ ├── tchan-hyperbahn.go │ │ └── ttypes.go │ ├── hyperbahn.thrift │ ├── utils.go │ └── utils_test.go ├── idle_sweep.go ├── idle_sweep_test.go ├── inbound.go ├── inbound_internal_test.go ├── inbound_test.go ├── incoming_test.go ├── init_test.go ├── internal/ │ ├── argreader/ │ │ ├── empty.go │ │ └── empty_test.go │ └── testcert/ │ └── testcert.go ├── introspection.go ├── introspection_test.go ├── json/ │ ├── call.go │ ├── context.go │ ├── handler.go │ ├── json_test.go │ ├── retry_test.go │ └── tracing_test.go ├── largereq_test.go ├── localip.go ├── localip_test.go ├── logger.go ├── logger_test.go ├── messages.go ├── messages_test.go ├── messagetype_string.go ├── mex.go ├── mex_utils_test.go ├── outbound.go ├── peer.go ├── peer_bench_test.go ├── peer_heap.go ├── peer_heap_test.go ├── peer_internal_test.go ├── peer_strategies.go ├── peer_test.go ├── 
peers/ │ ├── doc.go │ ├── prefer.go │ └── prefer_test.go ├── pprof/ │ ├── pprof.go │ └── pprof_test.go ├── preinit_connection.go ├── raw/ │ ├── call.go │ └── handler.go ├── relay/ │ ├── relay.go │ └── relaytest/ │ ├── func_host.go │ ├── mock_stats.go │ └── stub_host.go ├── relay.go ├── relay_api.go ├── relay_benchmark_test.go ├── relay_fragment_sender_test.go ├── relay_internal_test.go ├── relay_messages.go ├── relay_messages_benchmark_test.go ├── relay_messages_test.go ├── relay_test.go ├── relay_timer_pool.go ├── reqres.go ├── reqresreaderstate_string.go ├── reqreswriterstate_string.go ├── retry.go ├── retry_request_test.go ├── retry_test.go ├── retryon_string.go ├── root_peer_list.go ├── scripts/ │ ├── install-thrift.sh │ └── vbumper/ │ └── main.go ├── sockio_bsd.go ├── sockio_darwin.go ├── sockio_linux.go ├── sockio_non_unix.go ├── sockio_unix.go ├── stats/ │ ├── metrickey.go │ ├── metrickey_test.go │ ├── statsdreporter.go │ ├── tally.go │ └── tally_test.go ├── stats.go ├── stats_test.go ├── stats_utils_test.go ├── stream_test.go ├── stress_flag_test.go ├── subchannel.go ├── subchannel_test.go ├── systemerrcode_string.go ├── tchannel_test.go ├── testutils/ │ ├── call.go │ ├── channel.go │ ├── channel_opts.go │ ├── channel_t.go │ ├── conn.go │ ├── counter.go │ ├── counter_test.go │ ├── data.go │ ├── echo.go │ ├── goroutines/ │ │ ├── stacks.go │ │ ├── verify.go │ │ └── verify_opts.go │ ├── lists.go │ ├── logfilter_test.go │ ├── logger.go │ ├── mockhyperbahn/ │ │ ├── hyperbahn.go │ │ ├── hyperbahn_test.go │ │ └── utils.go │ ├── now.go │ ├── random_bench_test.go │ ├── relay.go │ ├── sleep.go │ ├── test_server.go │ ├── testreader/ │ │ ├── chunk.go │ │ ├── chunk_test.go │ │ ├── loop.go │ │ └── loop_test.go │ ├── testtracing/ │ │ ├── propagation.go │ │ └── propagation_test.go │ ├── testwriter/ │ │ ├── limited.go │ │ └── limited_test.go │ ├── thriftarg2test/ │ │ ├── arg2_kv.go │ │ └── arg2_kv_test.go │ ├── ticker.go │ ├── ticker_test.go │ ├── timeout.go │ └── wait.go 
├── thirdparty/ │ └── github.com/ │ └── apache/ │ └── thrift/ │ └── lib/ │ └── go/ │ └── thrift/ │ ├── application_exception.go │ ├── application_exception_test.go │ ├── binary_protocol.go │ ├── binary_protocol_test.go │ ├── buffered_transport.go │ ├── buffered_transport_test.go │ ├── compact_protocol.go │ ├── compact_protocol_test.go │ ├── debug_protocol.go │ ├── deserializer.go │ ├── exception.go │ ├── exception_test.go │ ├── field.go │ ├── framed_transport.go │ ├── framed_transport_test.go │ ├── http_client.go │ ├── http_client_test.go │ ├── http_transport.go │ ├── iostream_transport.go │ ├── iostream_transport_test.go │ ├── json_protocol.go │ ├── json_protocol_test.go │ ├── lowlevel_benchmarks_test.go │ ├── memory_buffer.go │ ├── memory_buffer_test.go │ ├── messagetype.go │ ├── multiplexed_protocol.go │ ├── numeric.go │ ├── pointerize.go │ ├── processor.go │ ├── processor_factory.go │ ├── protocol.go │ ├── protocol_exception.go │ ├── protocol_factory.go │ ├── protocol_test.go │ ├── rich_transport.go │ ├── rich_transport_test.go │ ├── serializer.go │ ├── serializer_test.go │ ├── serializer_types_test.go │ ├── server.go │ ├── server_socket.go │ ├── server_socket_test.go │ ├── server_test.go │ ├── server_transport.go │ ├── simple_json_protocol.go │ ├── simple_json_protocol_test.go │ ├── simple_server.go │ ├── socket.go │ ├── ssl_server_socket.go │ ├── ssl_socket.go │ ├── transport.go │ ├── transport_exception.go │ ├── transport_exception_test.go │ ├── transport_factory.go │ ├── transport_test.go │ ├── type.go │ ├── zlib_transport.go │ └── zlib_transport_test.go ├── thrift/ │ ├── arg2/ │ │ ├── kv_iterator.go │ │ └── kv_iterator_test.go │ ├── client.go │ ├── context.go │ ├── context_test.go │ ├── doc.go │ ├── errors_test.go │ ├── gen-go/ │ │ ├── meta/ │ │ │ ├── constants.go │ │ │ ├── meta.go │ │ │ └── ttypes.go │ │ └── test/ │ │ ├── constants.go │ │ ├── meta.go │ │ ├── secondservice.go │ │ ├── simpleservice.go │ │ ├── tchan-test.go │ │ └── ttypes.go │ ├── headers.go 
│ ├── headers_test.go │ ├── interfaces.go │ ├── meta.go │ ├── meta.thrift │ ├── meta_test.go │ ├── mocks/ │ │ ├── TChanMeta.go │ │ ├── TChanSecondService.go │ │ └── TChanSimpleService.go │ ├── options.go │ ├── server.go │ ├── server_test.go │ ├── struct.go │ ├── struct_test.go │ ├── tchan-meta.go │ ├── test.thrift │ ├── thrift-gen/ │ │ ├── compile_test.go │ │ ├── extends.go │ │ ├── generate.go │ │ ├── gopath.go │ │ ├── gopath_test.go │ │ ├── include.go │ │ ├── main.go │ │ ├── names.go │ │ ├── tchannel-template.go │ │ ├── template.go │ │ ├── test_files/ │ │ │ ├── binary.thrift │ │ │ ├── byte.thrift │ │ │ ├── gokeywords.thrift │ │ │ ├── include_test/ │ │ │ │ ├── namespace/ │ │ │ │ │ ├── a/ │ │ │ │ │ │ └── shared.thrift │ │ │ │ │ ├── b/ │ │ │ │ │ │ └── shared.thrift │ │ │ │ │ └── namespace.thrift │ │ │ │ ├── simple/ │ │ │ │ │ ├── shared.thrift │ │ │ │ │ ├── shared2.thrift │ │ │ │ │ └── simple.thrift │ │ │ │ └── svc_extend/ │ │ │ │ ├── shared.thrift │ │ │ │ └── svc_extend.thrift │ │ │ ├── multi_test/ │ │ │ │ ├── file1.thrift │ │ │ │ └── file2.thrift │ │ │ ├── service_extend.thrift │ │ │ ├── sets.thrift │ │ │ ├── test1.thrift │ │ │ ├── typedefs.thrift │ │ │ └── union.thrift │ │ ├── typestate.go │ │ ├── validate.go │ │ └── wrap.go │ ├── thrift_bench_test.go │ ├── thrift_test.go │ ├── tracing_test.go │ ├── transport.go │ └── transport_test.go ├── tnet/ │ ├── listener.go │ └── listener_test.go ├── tos/ │ ├── tos.go │ ├── tos_string.go │ └── tos_test.go ├── trace/ │ └── doc.go ├── tracing.go ├── tracing_internal_test.go ├── tracing_keys.go ├── tracing_test.go ├── trand/ │ └── rand.go ├── typed/ │ ├── buffer.go │ ├── buffer_test.go │ ├── reader.go │ ├── reader_test.go │ ├── writer.go │ └── writer_test.go ├── utils_for_test.go ├── verify_utils_test.go └── version.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/tests.yaml 
================================================ name: Tests on: push: branches: ['*'] tags: ['v*'] pull_request: branches: ['**'] jobs: test: runs-on: ubuntu-latest env: GOPATH: ${{ github.workspace }} GOBIN: ${{ github.workspace }}/bin defaults: run: working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} strategy: matrix: go: ["stable", "oldstable"] include: - go: "stable" latest: true COVERAGE: "yes" LINT: "yes" - go: "oldstable" LINT: "yes" steps: - name: Setup Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Checkout code uses: actions/checkout@v2 with: path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} - name: Load cache uses: actions/cache@v1 with: path: ~/.glide/cache key: ${{ runner.os }}-go-${{ hashFiles('**/glide.lock') }} restore-keys: | ${{ runner.os }}-go- - name: Install CI run: make install_ci - name: Test CI env: NO_TEST: ${{ matrix.NO_TEST }} run: test -n "$NO_TEST" || make test_ci - name: Cover CI env: COVERAGE: ${{ matrix.COVERAGE }} run: test -z "$COVERAGE" || make cover_ci - name: Lint CI env: LINT: ${{ matrix.LINT }} run: test -z "$LINT" || make install_lint lint - name: Crossdock CI run: make crossdock_logs_ci ================================================ FILE: .gitignore ================================================ build Godeps/_workspace vendor/ thrift-gen-release/ # Lint output lint.log # Cover profiles *.out # Editor and OS detritus *~ *.swp .DS_Store .idea tchannel-go.iml .vscode .bin/ .idea/ ================================================ FILE: CHANGELOG.md ================================================ Changelog ========= ## [1.34.6] - 2025-01-07 ### Fixed * Fix compile issue on FreeBSD 14 (#925) ## [1.34.5] - 2024-10-07 ### Changed * Add component tag to tracing spans (#923) ## [1.34.4] - 2024-06-26 ### Fixed * fix getSysConn to work with TLS (#918) ### Changed * Switch to aliases for Go versions in CI (#919) ## [1.34.3] - 2024-04-23 ### Fixed * Fix a DoS 
vulnerability of the vendored apache-thrift library (#915, #916) ## [1.34.2] - 2024-02-16 ### Added * Expose `inbound.cancels.{requested,honored}` metrics (#912) ## [1.34.1] - 2023-12-11 ### Fixed * Fix unknown error type in span tag rpc.tchannel.system_error_code (#907) ## [1.34.0] - 2023-10-17 ### Added * Emit the error code and type to the tracing span for inbound calls (#903) ### Changed * Update go.mod and github test to go 1.21 (#899) ## [1.33.0] - 2023-03-20 ### Added * Optionally send cancelled frames when context is canceled (#890) ### Changed * Update dependencies as per recommendations (#894) * Test against Go 1.19 and 1.20 (#892) ## [1.32.1] - 2022-08-23 ### Fixed * Release unsent frames when flushing fragments (#887) ## [1.32.0] - 2022-07-12 ### Changed * Add TLS option in testutils.NewServerChannel (#882) ## [1.31.0] - 2022-05-04 ### Changed * thrift-gen now uses the vendored apache-thrift code. The vendored apache-thrift is currently pinned to b2a4d4ae21c789b689dd162deb819665567f481c. ## [1.22.3] - 2022-03-28 ### Changed * Fix memory leak due to unreturned frames in the relayer. * Test against Go 1.17 and 1.18 in CI. * Migrate to go mod. ## [1.22.2] - 2021-11-10 ### Changed * Fixes related to `SkipHandlerMethods` request handling. ## [1.22.1] - 2021-10-25 ### Changed * Fixes related to `SkipHandlerMethods` request handling. ## [1.22.0] - 2021-08-13 ### Added * Add `SkipHandlerMethods` option as an allow-list of methods that are handled by the override handler. ### Changed * Allow method registration if handler implements Register. * Internal changes related to relaying. ## [1.21.2] - 2021-05-19 ### Changed * Internal changes related to relaying. ## [1.21.1] - 2021-03-17 ### Changed * Change log level for connection create/close from info level to debug level to reduce noisy logs. ## [1.21.0] - 2020-12-13 ### Changed * Internal changes related to relaying. 
## [1.20.1] - 2020-09-24 ### Fixed * Set ConnContext in the channel instead of the connection to avoid serialization errors since ConnectionOptions is sometimes embedded in configurations (#806) ## [1.20.0] - 2020-09-23 ### Added * Support per-connection base context propagation for inbound/outbound connections (#801) ### Changed * Internal API changes related to relaying. ## [1.19.1] - 2020-08-03 ### Fixed * Move OS-specific logic into OS-specific files to avoid compile issues on non-Unix platforms. ## [1.19.0] - 2020-05-21 ### Fixed * Internal API changes related to relaying. ## [1.18.0] - 2020-03-30 ### Added * Introspection now tracks last activity for reads and writes separately (#770) * Add options to allow overriding SendBufferSize per process name prefix (#772) ## [1.17.0] - 2020-02-18 ### Added * Internal API changes related to relaying. ## [1.16.0] - 2019-10-14 ### Added * Support custom Dialer for outbound connections (#759) ### Fixed * thrift: Handle TStruct serialization failures gracefully (#744) ## [1.15.0] - 2019-08-26 ### Added * introspection: Introspect any channel by ID. (#756) ### Fixed * Ensure Introspection endpoints are always available. (#755) * Fix testutils.WithTestServer incorrectly using RelayHost when creating the server. (#750) ## [1.14.0] - 2019-05-20 ### Added * Expose `CallOptions` caller name for transparent proxying (#741) ## [1.13.0] - 2019-04-04 ### Added * Add `MaxCloseTime` which sets a timeout for graceful connection close. (#724) ### Changed * Optimize Thrift string field serialization by eliminating `[]byte(string)` allocation. (#729) ### Fixed * Return an error if transport header keys/values exceed the maximum allowed string length. (#728) ## [1.12.0] - 2018-11-13 ### Added * Add a channel, `ClosedCh`, to wait for a channel to close. (#718) * Add a Code of Conduct. (#711) ### Changed * Tweak error message when sending a large error to mention that we're out of space. 
(#716) * Idle sweeper now skips connections that have pending calls. (#712) ## [1.11.0] - 2018-06-25 ### Added * thrift: Support health check type in Health endpoint. (#696) ## [1.10.0] - 2018-04-02 ### Added * Support blackholing requests to trigger client timeout without holding on to resources. (#681) * introspection: Include channel state in output. (#692) * introspection: Add inactive connections to output. (#686) ### Fixed * Inherit deadlines from the parent context if available and no timeout is specified. * Ensure outbound tracing headers take precedence over application headers. (#683) ## [1.9.0] - 2018-01-31 ### Added * stats: Add tally reporter to emit tagged metrics. (#676) * Add optional idle timeout, after which connections will be closed. (#681) ## [1.8.1] - 2017-11-21 ### Fixed * Always log addresses as strings. (#669) ## [1.8.0] - 2017-11-06 ### Added * Add opt-in active connection health checks. (#318) ### Changed * Improve error logging on `thrift.Server` errors. (#663) * Reduce memory usage for idle connections. (#658) * Unpin and reduce dependencies in `glide.yaml` by using `testImports`. (#649) ### Fixed * Don't close connections on ping errors. (#655) * Avoid holding on to closed connections' memory in peers. (#644) ## [1.7.0] - 2017-08-04 ### Added * Add `WithoutHeaders` to remove TChannel keys from a context. ### Changed * Cancel the context on incoming calls if the client connection is closed. ## [1.6.0] - 2017-06-02 ### Added * Add `OnPeerStatusChanged` channel option to receive a notification each time the number of available connections changes for any given peer. ### Changed * Locks Apache Thrift to version 0.9.3, 0.10.0 to maintain backward-compatibility. * Set DiffServ (QoS) bit on outbound connections. ### Fixed * Improve resilience of the frame parser. ## [1.5.0] - 2017-03-21 ### Added * Add `PeerList.Len` to expose the number of peers in the peer list. * Add `PeerList.GetNew` to only return previously unselected peers. 
## [1.4.0] - 2017-03-01 ### Added * Add version information to the channel's LocalPeerInfo. * Add peers package for peer management utilities such as consistent peer selection. ### Fixed * Fix SetScoreStrategy not rescoring existing peers. (#583). ## [1.3.0] - 2017-02-01 ### Added * Support Thrift namespaces for thrift-gen. * Exposes the channel's RootPeerList with `channel.RootPeers()`. ## [1.2.3] - 2017-01-19 ### Changed * Improve error messages when an argument reader is closed without reading the EOF. (#567) ### Fixed * thrift: Fix an issue where we return `nil` if we expected a Thrift exception but none was found (e.g., exception is from the future). (#566) * Fix ListenIP selecting docker interfaces over physical networks. (#565) * Fix for error when a Thrift payload has completed decoding and attempts to close the argument reader without waiting until EOF. (#564) * thrift-gen: Fix "namespace go" being ignored even though the Apache thrift generated code was respecting it. (#559) ## [1.2.2] - 2016-12-21 ### Added * Add a unique channel ID for introspection (#548) * Expose local peer information on {Inbound,Outbound}Call (#537) * Add remote peer info to connection logger and introspection (#514) ### Fixed * Don't drop existing headers on a context when using Wrap(ctx) (#547) * Setting response headers is not goroutine safe, allow using a child context for parallel sub-requests (#549). * Fix context cancellation not cancelling Dial attempts (#541) * Only select active connections for calls (#521) * Treat hostPorts ending in ":0" in the init headers as ephemeral (#513) ## [1.2.1] - 2016-09-29 ### Fixed * Fix data race on headers when making concurrent calls using the same context. (#505) ## [1.2.0] - 2016-09-15 ### Added * Adds support for routing keys (the TChannel rk transport header). ## [1.1.0] - 2016-08-25 ### Added * Integrate OpenTracing for distributed tracing and context propagation. 
As long as a Zipkin-style tracer is configured, TChannel frames still send tracing information, and `CurrentSpan(ctx)` works as before. All tracer configuration must be handled through OpenTracing. (#426) ### Changed * Improve error messages when using the json package and the host:port fails to connect. (#475) * mockhyperbahn now uses inbuilt TChannel relaying to implement in-process forwarding. (#472) * Drop go1.4 support and add support for go1.7. * Pass thrift.Context to the thrift.Server's response callback (#465) ## [1.0.9] - 2016-07-20 ### Added * Expose meta endpoints on the "tchannel" service name. (#459) * Add Go version and tchannel-go library version to introspection. (#457) * Expose the number of connections on a channel. (#451) ### Changed * Better handling of peers where dialed host:port doesn't match the remote connection's reported host:port. (#452) ## [1.0.8] - 2016-07-15 ### Fixed * Remove dependency on "testing" from "tchannel-go" introduced in v1.0.7. ## [1.0.7] - 2016-07-15 ### Added * Add CallOptions() to IncomingCall which can be used as the call option when making outbound calls to proxy all transport headers. * Add tracing information to all error frames generated by the library. * Add GetHandlers for getting all registered methods on a subchannel. * Expose the peer information for outbound calls. * Support a separate connection timeout from the context timeout, useful for streaming calls where the stream timeout may be much longer than the connection timeout. ### Fixed * Fix peer score not being calculated when adding new outbound connections ## [1.0.6] - 2016-06-16 ### Fixed * Fix trace span encoding fields in the wrong order ## [1.0.5] - 2016-04-04 ### Changed * Use `context.Context` storage for headers so `thrift.Context` and `tchannel.ContextWithHeaders` can be passed to functions that use `context.Context`, and have them retain headers. 
* `thrift.Server` allows a custom factory to be used for `thrift.Context` creation based on the underlying `context.Context` and headers map. * Store goroutine stack traces on channel creation that can be accessed via introspection. ## [1.0.4] - 2016-03-09 ### Added * #228: Add registered methods to the introspection output. * Add ability to set a global handler for a SubChannel. ### Fixed * Improve handling of network failures during pending calls. Previously, calls would time out, but now they fail as soon as the network failure is detected. * Remove ephemeral peers with closed outbound connections. * #233: Ensure errors returned from Thrift handlers have a non-nil value. ## 1.0.3 - 2016-02-15 ### Added * Introspection now includes information about all channels created in the current process. ### Changed * Improved performance when writing Thrift structs * Make closing message exchanges less disruptive: a panic caused by closing a channel twice is now an error log. ## [1.0.2] - 2016-01-29 ### Changed * Extend the `ContextBuilder` API to support setting the transport-level routing delegate header. * Assorted logging and test improvements. ### Fixed * Set a timeout when making new outbound connections to avoid hanging. * Fix for #196: Make the initial Hyperbahn advertise more tolerant of transient timeouts. ## [1.0.1] - 2016-01-19 ### Added * Peers can now be removed using PeerList.Remove. * Add ErrorHandlerFunc to create raw handlers that return errors. * Retries try to avoid previously selected hosts, rather than just the host:port. * Create an ArgReader interface (which is an alias for io.ReadCloser) for symmetry with ArgWriter. * Add ArgReadable and ArgWritable interfaces for the common methods between calls and responses. * Expose Thrift binary encoding methods (thrift.ReadStruct, thrift.WriteStruct, thrift.ReadHeaders, thrift.WriteHeaders) so callers can easily send Thrift payloads over the streaming interface. 
### Fixed * Bug fix for #181: Shuffle peers on PeerList.Add to avoid biases in peer selection. ## 1.0.0 - 2016-01-11 ### Added * First stable release. * Support making calls with JSON, Thrift or raw payloads. * Services use thrift-gen, and implement handlers with a `func(ctx, arg) (res, error)` signature. * Support retries. * Peer selection (peer heap, prefer incoming strategy, for use with Hyperbahn). * Graceful channel shutdown. * TCollector trace reporter with sampling support. * Metrics collection with StatsD. * Thrift support, including includes. [//]: # (Version Links) [1.34.6]: https://github.com/uber/tchannel-go/compare/v1.34.5...v1.34.6 [1.34.5]: https://github.com/uber/tchannel-go/compare/v1.34.4...v1.34.5 [1.34.4]: https://github.com/uber/tchannel-go/compare/v1.34.3...v1.34.4 [1.34.3]: https://github.com/uber/tchannel-go/compare/v1.34.2...v1.34.3 [1.34.2]: https://github.com/uber/tchannel-go/compare/v1.34.1...v1.34.2 [1.34.1]: https://github.com/uber/tchannel-go/compare/v1.34.0...v1.34.1 [1.34.0]: https://github.com/uber/tchannel-go/compare/v1.33.0...v1.34.0 [1.33.0]: https://github.com/uber/tchannel-go/compare/v1.32.1...v1.33.0 [1.32.1]: https://github.com/uber/tchannel-go/compare/v1.32.0...v1.32.1 [1.32.0]: https://github.com/uber/tchannel-go/compare/v1.31.0...v1.32.0 [1.31.0]: https://github.com/uber/tchannel-go/compare/v1.22.3...v1.31.0 [1.22.3]: https://github.com/uber/tchannel-go/compare/v1.22.2...v1.22.3 [1.22.2]: https://github.com/uber/tchannel-go/compare/v1.22.1...v1.22.2 [1.22.1]: https://github.com/uber/tchannel-go/compare/v1.22.0...v1.22.1 [1.22.0]: https://github.com/uber/tchannel-go/compare/v1.21.2...v1.22.0 [1.21.2]: https://github.com/uber/tchannel-go/compare/v1.21.1...v1.21.2 [1.21.1]: https://github.com/uber/tchannel-go/compare/v1.21.0...v1.21.1 [1.21.0]: https://github.com/uber/tchannel-go/compare/v1.20.1...v1.21.0 [1.20.1]: https://github.com/uber/tchannel-go/compare/v1.20.0...v1.20.1 [1.20.0]: https://github.com/uber/tchannel-go/compare/v1.19.1...v1.20.0 [1.19.1]: https://github.com/uber/tchannel-go/compare/v1.19.0...v1.19.1 [1.19.0]: https://github.com/uber/tchannel-go/compare/v1.18.0...v1.19.0 [1.18.0]: 
https://github.com/uber/tchannel-go/compare/v1.17.0...v1.18.0 [1.17.0]: https://github.com/uber/tchannel-go/compare/v1.16.0...v1.17.0 [1.16.0]: https://github.com/uber/tchannel-go/compare/v1.15.0...v1.16.0 [1.15.0]: https://github.com/uber/tchannel-go/compare/v1.14.0...v1.15.0 [1.14.0]: https://github.com/uber/tchannel-go/compare/v1.13.0...v1.14.0 [1.13.0]: https://github.com/uber/tchannel-go/compare/v1.12.0...v1.13.0 [1.12.0]: https://github.com/uber/tchannel-go/compare/v1.11.0...v1.12.0 [1.11.0]: https://github.com/uber/tchannel-go/compare/v1.10.0...v1.11.0 [1.10.0]: https://github.com/uber/tchannel-go/compare/v1.9.0...v1.10.0 [1.9.0]: https://github.com/uber/tchannel-go/compare/v1.8.1...v1.9.0 [1.8.1]: https://github.com/uber/tchannel-go/compare/v1.8.0...v1.8.1 [1.8.0]: https://github.com/uber/tchannel-go/compare/v1.7.0...v1.8.0 [1.7.0]: https://github.com/uber/tchannel-go/compare/v1.6.0...v1.7.0 [1.6.0]: https://github.com/uber/tchannel-go/compare/v1.5.0...v1.6.0 [1.5.0]: https://github.com/uber/tchannel-go/compare/v1.4.0...v1.5.0 [1.4.0]: https://github.com/uber/tchannel-go/compare/v1.3.0...v1.4.0 [1.3.0]: https://github.com/uber/tchannel-go/compare/v1.2.3...v1.3.0 [1.2.3]: https://github.com/uber/tchannel-go/compare/v1.2.2...v1.2.3 [1.2.2]: https://github.com/uber/tchannel-go/compare/v1.2.1...v1.2.2 [1.2.1]: https://github.com/uber/tchannel-go/compare/v1.2.0...v1.2.1 [1.2.0]: https://github.com/uber/tchannel-go/compare/v1.1.0...v1.2.0 [1.1.0]: https://github.com/uber/tchannel-go/compare/v1.0.9...v1.1.0 [1.0.9]: https://github.com/uber/tchannel-go/compare/v1.0.8...v1.0.9 [1.0.8]: https://github.com/uber/tchannel-go/compare/v1.0.7...v1.0.8 [1.0.7]: https://github.com/uber/tchannel-go/compare/v1.0.6...v1.0.7 [1.0.6]: https://github.com/uber/tchannel-go/compare/v1.0.5...v1.0.6 [1.0.5]: https://github.com/uber/tchannel-go/compare/v1.0.4...v1.0.5 [1.0.4]: https://github.com/uber/tchannel-go/compare/v1.0.2...v1.0.4 [1.0.2]: 
https://github.com/uber/tchannel-go/compare/v1.0.1...v1.0.2 [1.0.1]: https://github.com/uber/tchannel-go/compare/v1.0.0...v1.0.1 ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at oss-conduct@uber.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]. [homepage]: http://contributor-covenant.org [version]: http://contributor-covenant.org/version/1/4/ ================================================ FILE: CONTRIBUTING.md ================================================ Contributing ============ We'd love your help making tchannel-go great! 
## Getting Started TChannel uses [glide](https://github.com/Masterminds/glide) to manage dependencies. To get started: ```bash go get github.com/uber/tchannel-go make install_glide make # tests should pass ``` ## Making A Change *Before making any significant changes, please [open an issue](https://github.com/uber/tchannel-go/issues).* Discussing your proposed changes ahead of time will make the contribution process smooth for everyone. Once we've discussed your changes and you've got your code ready, make sure that tests are passing (`make test` or `make cover`) and open your PR! Your pull request is most likely to be accepted if it: * Includes tests for new functionality. * Follows the guidelines in [Effective Go](https://golang.org/doc/effective_go.html) and the [Go team's common code review comments](https://github.com/golang/go/wiki/CodeReviewComments). * Has a [good commit message](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html). ## Cutting a Release * Send a pull request against dev including: * update CHANGELOG.md (`scripts/changelog_halp.sh`) * update version.go * Send a pull request for dev into master * `git tag -m v0.0.0 -a v0.0.0` * `git push origin --tags` * Copy CHANGELOG.md fragment into release notes on https://github.com/uber/tchannel-go/releases ================================================ FILE: LICENSE.md ================================================ Copyright (c) 2015 Uber Technologies, Inc. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Makefile ================================================ PATH := $(GOPATH)/bin:$(PATH) EXAMPLES=./examples/bench/server ./examples/bench/client ./examples/ping ./examples/thrift ./examples/hyperbahn/echo-server ALL_PKGS := $(shell go list ./... | grep -v 'thirdparty') PROD_PKGS := . 
./http ./hyperbahn ./json ./peers ./pprof ./raw ./relay ./stats ./thrift $(EXAMPLES) TEST_ARG ?= -race -v -timeout 10m COV_PKG ?= ./ BUILD := ./build THRIFT_GEN_RELEASE := ./thrift-gen-release THRIFT_GEN_RELEASE_LINUX := $(THRIFT_GEN_RELEASE)/linux-x86_64 THRIFT_GEN_RELEASE_DARWIN := $(THRIFT_GEN_RELEASE)/darwin-x86_64 PLATFORM := $(shell uname -s | tr '[:upper:]' '[:lower:]') ARCH := $(shell uname -m) BIN := $(shell pwd)/.bin # Cross language test args TEST_HOST=127.0.0.1 TEST_PORT=0 -include crossdock/rules.mk all: lint test examples $(BIN)/thrift: mkdir -p $(BIN) scripts/install-thrift.sh $(BIN) packages_test: go list -json ./... | jq -r '. | select ((.TestGoFiles | length) > 0) | .ImportPath' setup: mkdir -p $(BUILD) mkdir -p $(BUILD)/examples mkdir -p $(THRIFT_GEN_RELEASE_LINUX) mkdir -p $(THRIFT_GEN_RELEASE_DARWIN) install: go mod vendor install_lint: @echo "Installing golint, since we expect to lint" go install golang.org/x/lint/golint@latest install_ci: $(BIN)/thrift install ifdef CROSSDOCK $(MAKE) install_docker_ci endif help: @egrep "^# target:" [Mm]akefile | sort - clean: echo Cleaning build artifacts... go clean rm -rf $(BUILD) $(THRIFT_GEN_RELEASE) echo fmt format: echo Formatting Packages... go fmt $(ALL_PKGS) echo test_ci: ifdef CROSSDOCK $(MAKE) crossdock_ci else $(MAKE) test endif test: clean setup check_no_test_deps $(BIN)/thrift $(MAKE) test_vanilla $(MAKE) test_relay_frame_leaks # test_vanilla runs all unit tests without checking for frame leaks test_vanilla: @echo Testing packages: PATH=$(BIN):$$PATH DISABLE_FRAME_POOLING_CHECKS=1 go test -parallel=4 $(TEST_ARG) $(ALL_PKGS) @echo Running frame pool tests PATH=$(BIN):$$PATH go test -run TestFramesReleased -stressTest $(TEST_ARG) # test_relay_frame_leaks runs unit tests in relay_test.go with frame leak checks enabled test_relay_frame_leaks: @echo Testing relay frame leaks PATH=$(BIN):$$PATH go test -parallel=4 $(TEST_ARG) relay_test.go check_no_test_deps: ! 
go list -json $(PROD_PKGS) | jq -r '.Deps | select ((. | length) > 0) | .[]' | grep -e test -e mock | grep -v '^internal/testlog'

# benchmark runs every benchmark on a single CPU with allocation statistics.
# $(BIN) is prepended to PATH with a single ':'; an empty PATH entry ('::')
# would put the current directory on the search path.
benchmark: clean setup $(BIN)/thrift
	echo Running benchmarks:
	PATH=$(BIN):$$PATH go test $(ALL_PKGS) -bench=. -cpu=1 -benchmem -run NONE

# cover_profile writes a coverage profile for $(COV_PKG) to $(BUILD)/coverage.out.
cover_profile: clean setup $(BIN)/thrift
	@echo Testing packages:
	mkdir -p $(BUILD)
	PATH=$(BIN):$$PATH DISABLE_FRAME_POOLING_CHECKS=1 go test $(COV_PKG) $(TEST_ARG) -coverprofile=$(BUILD)/coverage.out

# cover opens the coverage profile as HTML.
cover: cover_profile
	go tool cover -html=$(BUILD)/coverage.out

# cover_ci generates the coverage profile and uploads it with the codecov script.
cover_ci:
	@echo "Uploading coverage"
	$(MAKE) cover_profile
	curl -s https://codecov.io/bash > $(BUILD)/codecov.bash
	bash $(BUILD)/codecov.bash -f $(BUILD)/coverage.out


# FILTER drops generated, vendored and third-party paths from lint output.
FILTER := grep -v -e '_string.go' -e '/gen-go/' -e '/mocks/' -e 'vendor/' -e 'thirdparty'

# lint collects golint, go vet, gofmt and FIXME findings in lint.log and fails
# when that file is non-empty.
lint: install
	@echo "Running golint"
	-golint $(ALL_PKGS) | $(FILTER) | tee lint.log
	@echo "Running go vet"
	-go vet $(ALL_PKGS) 2>&1 | $(FILTER) | fgrep -v -e "possible formatting directiv" -e "exit status" | tee -a lint.log
	@echo "Verifying files are gofmt'd"
	-gofmt -l . | $(FILTER) | tee -a lint.log
	@echo "Checking for unresolved FIXMEs"
	-git grep -i -n fixme | $(FILTER) | grep -v -e Makefile | tee -a lint.log
	@[ ! -s lint.log ]

thrift_example: thrift_gen
	go build -o $(BUILD)/examples/thrift ./examples/thrift/main.go

test_server:
	./build/examples/test_server --host ${TEST_HOST} --port ${TEST_PORT}

examples: clean setup thrift_example
	echo Building examples...
mkdir -p $(BUILD)/examples/ping $(BUILD)/examples/bench go build -o $(BUILD)/examples/ping/pong ./examples/ping/main.go go build -o $(BUILD)/examples/hyperbahn/echo-server ./examples/hyperbahn/echo-server/main.go go build -o $(BUILD)/examples/bench/server ./examples/bench/server go build -o $(BUILD)/examples/bench/client ./examples/bench/client go build -o $(BUILD)/examples/bench/runner ./examples/bench/runner.go go build -o $(BUILD)/examples/test_server ./examples/test_server thrift_gen: $(BIN)/thrift go build -o $(BUILD)/thrift-gen ./thrift/thrift-gen PATH=$(BIN):$$PATH $(BUILD)/thrift-gen --generateThrift --inputFile thrift/test.thrift --outputDir thrift/gen-go/ PATH=$(BIN):$$PATH $(BUILD)/thrift-gen --generateThrift --inputFile examples/keyvalue/keyvalue.thrift --outputDir examples/keyvalue/gen-go PATH=$(BIN):$$PATH $(BUILD)/thrift-gen --generateThrift --inputFile examples/thrift/example.thrift --outputDir examples/thrift/gen-go PATH=$(BIN):$$PATH $(BUILD)/thrift-gen --generateThrift --inputFile hyperbahn/hyperbahn.thrift --outputDir hyperbahn/gen-go release_thrift_gen: clean setup GOOS=linux GOARCH=amd64 go build -o $(THRIFT_GEN_RELEASE_LINUX)/thrift-gen ./thrift/thrift-gen GOOS=darwin GOARCH=amd64 go build -o $(THRIFT_GEN_RELEASE_DARWIN)/thrift-gen ./thrift/thrift-gen tar -czf thrift-gen-release.tar.gz $(THRIFT_GEN_RELEASE) mv thrift-gen-release.tar.gz $(THRIFT_GEN_RELEASE)/ .PHONY: all help clean fmt format install install_ci install_lint release_thrift_gen packages_test check_no_test_deps test test_ci lint .SILENT: all help clean fmt format test lint ================================================ FILE: README.md ================================================ # TChannel [![GoDoc][doc-img]][doc] [![Build Status][ci-img]][ci] [![Coverage Status][cov-img]][cov] [TChannel][tchan-spec] is a multiplexing and framing protocol for RPC calls. tchannel-go is a Go implementation of the protocol, including client libraries for [Hyperbahn][hyperbahn]. 
If you'd like to start by writing a small Thrift and TChannel service, check out [this guide](guide/Thrift_Hyperbahn.md). For a less opinionated setup, see the [contribution guidelines](CONTRIBUTING.md). ## Overview TChannel is a network protocol that supports: * A request/response model, * Multiplexing multiple requests across the same TCP socket, * Out-of-order responses, * Streaming requests and responses, * Checksummed frames, * Transport of arbitrary payloads, * Easy implementation in many languages, and * Redis-like performance. This protocol is intended to run on datacenter networks for inter-process communication. ## Protocol TChannel frames have a fixed-length header and 3 variable-length fields. The underlying protocol does not assign meaning to these fields, but the included client/server implementation uses the first field to represent a unique endpoint or function name in an RPC model. The next two fields can be used for arbitrary data. Some suggested ways to use the 3 fields are: * URI path + HTTP method and headers as JSON + body, or * Function name + headers + thrift/protobuf. Note, however, that the only encoding supported by TChannel is UTF-8. If you want JSON, you'll need to stringify and parse outside of TChannel. This design supports efficient routing and forwarding: routers need to parse the first or second field, but can forward the third field without parsing. There is no notion of client and server in this system. Every TChannel instance is capable of making and receiving requests, and thus requires a unique port on which to listen. This requirement may change in the future. See the [protocol specification][tchan-proto-spec] for more details. ## Examples - [ping](examples/ping): A simple ping/pong example using raw TChannel. - [thrift](examples/thrift): A Thrift server/client example. - [keyvalue](examples/keyvalue): A keyvalue Thrift service with separate server and client binaries.
This project is released under the [MIT License](LICENSE.md). [doc-img]: https://godoc.org/github.com/uber/tchannel-go?status.svg [doc]: https://godoc.org/github.com/uber/tchannel-go [ci-img]: https://github.com/uber/tchannel-go/actions/workflows/tests.yaml/badge.svg?branch=master [ci]: https://github.com/uber/tchannel-go/actions/workflows/tests.yaml [cov-img]: https://coveralls.io/repos/uber/tchannel-go/badge.svg?branch=master&service=github [cov]: https://coveralls.io/github/uber/tchannel-go?branch=master [tchan-spec]: http://tchannel.readthedocs.org/en/latest/ [tchan-proto-spec]: http://tchannel.readthedocs.org/en/latest/protocol/ [hyperbahn]: https://github.com/uber/hyperbahn ================================================ FILE: RELEASE.md ================================================ Release process =============== This document outlines how to create a release of tchannel-go. 1. Set up some environment variables for use later. ``` # This is the version being released. $ VERSION=1.8.0 ``` 2. Make sure you have the latest dev and create a branch off it. ``` $ git checkout dev $ git pull $ git checkout -b release ``` 3. Update the `CHANGELOG.md` and `version.go` files. ``` $ go run ./scripts/vbumper/main.go --version $VERSION ``` 4. Clean up the `CHANGELOG.md` to only mention noteworthy changes for users. 5. Commit changes and create a PR against `dev` to prepare for release. 6. Once the release PR has been accepted, run the following to release. ``` $ git checkout master $ git pull $ git merge dev $ git tag -a "v$VERSION" -m "v$VERSION" $ git push origin master v$VERSION ``` 7. Go to https://github.com/uber/tchannel-go/releases and edit the release notes. Copy changelog entries for this release and set the name to `v$VERSION`. 8. Switch back to development.
``` $ git checkout dev $ git merge master $ go run ./scripts/vbumper/main.go --version ${VERSION}-dev --skip-changelog $ git commit -am "Back to development" $ git push ``` ================================================ FILE: all_channels.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "fmt" "sync" ) // channelMap is used to ensure that applications don't create multiple channels with // the same service name in a single process. 
var channelMap = struct { sync.Mutex existing map[string][]*Channel }{ existing: make(map[string][]*Channel), } func registerNewChannel(ch *Channel) { serviceName := ch.ServiceName() ch.createdStack = string(getStacks(false /* all */)) ch.log.WithFields( LogField{"channelPtr", fmt.Sprintf("%p", ch)}, LogField{"createdStack", ch.createdStack}, ).Info("Created new channel.") channelMap.Lock() defer channelMap.Unlock() existing := channelMap.existing[serviceName] channelMap.existing[serviceName] = append(existing, ch) } func removeClosedChannel(ch *Channel) { channelMap.Lock() defer channelMap.Unlock() channels := channelMap.existing[ch.ServiceName()] for i, v := range channels { if v != ch { continue } // Replace current index with the last element, and truncate channels. channels[i] = channels[len(channels)-1] channels = channels[:len(channels)-1] break } channelMap.existing[ch.ServiceName()] = channels } func findChannelByID(id uint32) (*Channel, bool) { channelMap.Lock() defer channelMap.Unlock() for _, channels := range channelMap.existing { for _, ch := range channels { if ch.chID == id { return ch, true } } } return nil, false } ================================================ FILE: all_channels_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestAllChannelsRegistered(t *testing.T) { introspectOpts := &IntrospectionOptions{IncludeOtherChannels: true} ch1_1, err := NewChannel("ch1", nil) require.NoError(t, err, "Channel create failed") ch1_2, err := NewChannel("ch1", nil) require.NoError(t, err, "Channel create failed") ch2_1, err := NewChannel("ch2", nil) require.NoError(t, err, "Channel create failed") state := ch1_1.IntrospectState(introspectOpts) assert.Equal(t, 1, len(state.OtherChannels["ch1"])) assert.Equal(t, 1, len(state.OtherChannels["ch2"])) ch1_2.Close() state = ch1_1.IntrospectState(introspectOpts) assert.Equal(t, 0, len(state.OtherChannels["ch1"])) assert.Equal(t, 1, len(state.OtherChannels["ch2"])) ch2_2, err := NewChannel("ch2", nil) state = ch1_1.IntrospectState(introspectOpts) require.NoError(t, err, "Channel create failed") assert.Equal(t, 0, len(state.OtherChannels["ch1"])) assert.Equal(t, 2, len(state.OtherChannels["ch2"])) ch1_1.Close() ch2_1.Close() ch2_2.Close() state = ch1_1.IntrospectState(introspectOpts) assert.Equal(t, 0, len(state.OtherChannels["ch1"])) assert.Equal(t, 0, len(state.OtherChannels["ch2"])) } ================================================ FILE: arguments.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"bufio"
	"encoding/json"
	"io"
	"io/ioutil"

	"github.com/uber/tchannel-go/internal/argreader"
)

// ArgReader is the interface for the arg2 and arg3 streams on an
// OutboundCallResponse and an InboundCall. It is declared as an
// io.ReadCloser.
type ArgReader io.ReadCloser

// ArgWriter is the interface for the arg2 and arg3 streams on an OutboundCall
// and an InboundCallResponse. It extends io.WriteCloser with Flush.
type ArgWriter interface {
	io.WriteCloser

	// Flush flushes the currently written bytes without waiting for the frame
	// to be filled.
	Flush() error
}

// ArgWritable is an interface for providing arg2 and arg3 writer streams;
// implemented by reqResWriter e.g. OutboundCall and InboundCallResponse.
type ArgWritable interface {
	// Arg2Writer returns the writer for the arg2 stream.
	Arg2Writer() (ArgWriter, error)
	// Arg3Writer returns the writer for the arg3 stream.
	Arg3Writer() (ArgWriter, error)
}

// ArgReadable is an interface for providing arg2 and arg3 reader streams;
// implemented by reqResReader e.g. InboundCall and OutboundCallResponse.
type ArgReadable interface { Arg2Reader() (ArgReader, error) Arg3Reader() (ArgReader, error) } // ArgReadHelper providers a simpler interface to reading arguments. type ArgReadHelper struct { reader ArgReader err error } // NewArgReader wraps the result of calling ArgXReader to provide a simpler // interface for reading arguments. func NewArgReader(reader ArgReader, err error) ArgReadHelper { return ArgReadHelper{reader, err} } func (r ArgReadHelper) read(f func() error) error { if r.err != nil { return r.err } if err := f(); err != nil { return err } if err := argreader.EnsureEmpty(r.reader, "read arg"); err != nil { return err } return r.reader.Close() } // Read reads from the reader into the byte slice. func (r ArgReadHelper) Read(bs *[]byte) error { return r.read(func() error { var err error *bs, err = ioutil.ReadAll(r.reader) return err }) } // ReadJSON deserializes JSON from the underlying reader into data. func (r ArgReadHelper) ReadJSON(data interface{}) error { return r.read(func() error { // TChannel allows for 0 length values (not valid JSON), so we use a bufio.Reader // to check whether data is of 0 length. reader := bufio.NewReader(r.reader) if _, err := reader.Peek(1); err == io.EOF { // If the data is 0 length, then we don't try to read anything. return nil } else if err != nil { return err } d := json.NewDecoder(reader) return d.Decode(data) }) } // ArgWriteHelper providers a simpler interface to writing arguments. type ArgWriteHelper struct { writer io.WriteCloser err error } // NewArgWriter wraps the result of calling ArgXWriter to provider a simpler // interface for writing arguments. func NewArgWriter(writer io.WriteCloser, err error) ArgWriteHelper { return ArgWriteHelper{writer, err} } func (w ArgWriteHelper) write(f func() error) error { if w.err != nil { return w.err } if err := f(); err != nil { return err } return w.writer.Close() } // Write writes the given bytes to the underlying writer. 
func (w ArgWriteHelper) Write(bs []byte) error { return w.write(func() error { _, err := w.writer.Write(bs) return err }) } // WriteJSON writes the given object as JSON. func (w ArgWriteHelper) WriteJSON(data interface{}) error { return w.write(func() error { e := json.NewEncoder(w.writer) return e.Encode(data) }) } ================================================ FILE: arguments_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"bytes"
	"io"
	"io/ioutil"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// bufferWithClose is a bytes.Buffer that records whether Close was called.
type bufferWithClose struct {
	*bytes.Buffer
	closed bool
}

// Compile-time checks that bufferWithClose can serve as both arg stream kinds.
var (
	_ io.WriteCloser = &bufferWithClose{}
	_ io.ReadCloser  = &bufferWithClose{}
)

// newWriter returns an empty, not-yet-closed buffer.
func newWriter() *bufferWithClose {
	return &bufferWithClose{Buffer: bytes.NewBuffer(nil), closed: false}
}

// newReader returns a not-yet-closed buffer whose contents are bs.
func newReader(bs []byte) *bufferWithClose {
	return &bufferWithClose{Buffer: bytes.NewBuffer(bs), closed: false}
}

// Close marks the buffer as closed and never fails.
func (w *bufferWithClose) Close() error {
	w.closed = true
	return nil
}

// testObject is the JSON payload used by the tests below.
type testObject struct {
	Name  string `json:"name"`
	Value int    `json:"value"`
}

// TestJSONInputOutput writes an object as JSON, reads it back, and checks that
// both helpers closed their streams.
func TestJSONInputOutput(t *testing.T) {
	in := testObject{Name: "Foo", Value: 20756}

	sink := newWriter()
	require.Nil(t, NewArgWriter(sink, nil).WriteJSON(in))
	assert.True(t, sink.closed)
	assert.Equal(t, "{\"name\":\"Foo\",\"value\":20756}\n", sink.String())

	source := newReader(sink.Bytes())
	var out testObject
	require.Nil(t, NewArgReader(source, nil).ReadJSON(&out))

	assert.True(t, source.closed)
	assert.Equal(t, "Foo", out.Name)
	assert.Equal(t, 20756, out.Value)
}

// TestReadNotEmpty checks that ReadJSON fails when bytes follow the first
// JSON value.
func TestReadNotEmpty(t *testing.T) {
	// Note: The contents need to be larger than the default buffer size of bufio.NewReader.
	payload := []byte("{}" + strings.Repeat("{}\n", 10000))

	var decoded map[string]interface{}
	helper := NewArgReader(ioutil.NopCloser(bytes.NewReader(payload)), nil)
	require.Error(t, helper.ReadJSON(&decoded), "Read should fail due to extra bytes")
}

// BenchmarkArgReaderWriter measures a JSON write followed by a JSON read of
// the same object.
func BenchmarkArgReaderWriter(b *testing.B) {
	in := testObject{Name: "Foo", Value: 20756}
	var out testObject

	for i := 0; i < b.N; i++ {
		sink := newWriter()
		NewArgWriter(sink, nil).WriteJSON(in)
		NewArgReader(newReader(sink.Bytes()), nil).ReadJSON(&out)
	}

	b.StopTimer()
	assert.Equal(b, in, out)
}

================================================
FILE: benchmark/benchclient/main.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // benchclient is used to make requests to a specific server. 
package main

import (
	"bufio"
	"flag"
	"fmt"
	"log"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/uber/tchannel-go/benchmark"
)

var (
	serviceName = flag.String("service", "bench-server", "The benchmark server's service name")
	timeout     = flag.Duration("timeout", time.Second, "Timeout for each request")
	requestSize = flag.Int("request-size", 10000, "The number of bytes of each request")
	noLibrary   = flag.Bool("no-library", false, "Whether to use the template based library instead of TChannel's client library")
	numClients  = flag.Int("num-clients", 1, "Number of concurrent clients to run in process")
	noDurations = flag.Bool("no-durations", false, "Disable printing of latencies to stdout")
)

// main creates a benchmark client for the hosts given as positional arguments
// and then executes commands read from stdin, one per line: "warmup",
// "rcall <n>", "tcall <n>" and "quit". Any other command is fatal.
func main() {
	flag.Parse()

	client := benchmark.NewClient(flag.Args(), clientOptions()...)
	fmt.Println("bench-client started")

	input := bufio.NewScanner(os.Stdin)
	for input.Scan() {
		line := input.Text()
		fields := strings.Split(line, " ")
		count := callCount(fields)

		switch fields[0] {
		case "warmup":
			if err := client.Warmup(); err != nil {
				log.Fatalf("warmup failed: %v", err)
			}
			fmt.Println("success")
		case "rcall":
			makeCalls(count, client.RawCall)
		case "tcall":
			makeCalls(count, client.ThriftCall)
		case "quit":
			return
		default:
			log.Fatalf("unrecognized command: %v", line)
		}
	}

	if err := input.Err(); err != nil {
		log.Fatalf("Reader failed: %v", err)
	}
}

// clientOptions translates the parsed command-line flags into benchmark options.
func clientOptions() []benchmark.Option {
	opts := []benchmark.Option{
		benchmark.WithServiceName(*serviceName),
		benchmark.WithTimeout(*timeout),
		benchmark.WithRequestSize(*requestSize),
		benchmark.WithNumClients(*numClients),
	}
	if *noLibrary {
		opts = append(opts, benchmark.WithNoLibrary())
	}
	return opts
}

// callCount parses the second field of a command as the number of calls to
// make. It returns 0 when the field is absent and exits the process when the
// field is not a number.
func callCount(fields []string) int {
	if len(fields) < 2 {
		return 0
	}
	n, err := strconv.Atoi(fields[1])
	if err != nil {
		log.Fatalf("unrecognized number %q: %v", fields[1], err)
	}
	return n
}

// makeCalls invokes f with n and prints the returned durations on one
// space-separated line. With --no-durations only the terminating newline is
// printed. A failed call is fatal.
func makeCalls(n int, f func(n int) ([]time.Duration, error)) {
	durations, err := f(n)
	if err != nil {
		log.Fatalf("Call failed: %v", err)
	}

	if *noDurations {
		fmt.Println()
		return
	}

	for i, d := range durations {
		if i != 0 {
			fmt.Printf(" ")
		}
		fmt.Printf("%v", d)
	}
	fmt.Println()
}

================================================
FILE: benchmark/benchserver/main.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // benchserver is used to receive requests for benchmarks. package main import ( "bufio" "flag" "fmt" "io" "log" "os" "strings" "github.com/uber/tchannel-go/benchmark" ) var ( serviceName = flag.String("service", "bench-server", "The benchmark server's service name") advertiseHosts = flag.String("advertise-hosts", "", "Comma-separated list of hosts to advertise to") ) func main() { flag.Parse() var adHosts []string if len(*advertiseHosts) > 0 { adHosts = strings.Split(*advertiseHosts, ",") } server := benchmark.NewServer( benchmark.WithServiceName(*serviceName), benchmark.WithAdvertiseHosts(adHosts), ) fmt.Println(server.HostPort()) rdr := bufio.NewReader(os.Stdin) for { line, err := rdr.ReadString('\n') if err != nil { if err == io.EOF { return } log.Fatalf("stdin read failed: %v", err) } line = strings.TrimSuffix(line, "\n") switch line { case "count-raw": fmt.Println(server.RawCalls()) case "count-thrift": fmt.Println(server.ThriftCalls()) case "quit": return default: log.Fatalf("unrecognized command: %v", line) } } } ================================================ FILE: benchmark/build_manager.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package benchmark

import (
	"io/ioutil"
	"os"
	"os/exec"
	"sync"
)

// buildManager caches one build per main file so that each file is compiled
// at most once.
type buildManager struct {
	sync.RWMutex
	builds map[string]*build
}

// build holds the outcome of compiling a single main file.
type build struct {
	once       sync.Once
	mainFile   string
	binaryFile string
	buildErr   error
}

// newBuildManager returns a buildManager with no cached builds.
func newBuildManager() *buildManager {
	return &buildManager{
		builds: make(map[string]*build),
	}
}

// GoBinary returns the path of the binary built from mainFile, compiling it
// on first use. The result, including a build error, is cached per mainFile.
func (m *buildManager) GoBinary(mainFile string) (string, error) {
	m.Lock()
	bld, ok := m.builds[mainFile]
	if !ok {
		bld = &build{mainFile: mainFile}
		m.builds[mainFile] = bld
	}
	m.Unlock()

	// The compile runs outside the manager lock; sync.Once serializes
	// concurrent requests for the same file.
	bld.once.Do(bld.Build)
	return bld.binaryFile, bld.buildErr
}

// Build compiles b.mainFile with "go build" into a fresh temporary file. On
// success the path is stored in b.binaryFile; on failure the error is stored
// in b.buildErr and the temporary file is removed. It panics when the
// temporary file cannot be created.
func (b *build) Build() {
	tempFile, err := ioutil.TempFile("", "bench")
	if err != nil {
		panic("Failed to create temp file: " + err.Error())
	}
	// Only the name is needed; go build writes the output itself.
	tempFile.Close()
	binPath := tempFile.Name()

	buildCmd := exec.Command("go", "build", "-o", binPath, b.mainFile)
	buildCmd.Stdout = os.Stdout
	buildCmd.Stderr = os.Stderr
	if err := buildCmd.Run(); err != nil {
		// Best-effort cleanup so a failed build does not leave the
		// placeholder file behind; the build error is what gets reported.
		_ = os.Remove(binPath)
		b.buildErr = err
		return
	}

	b.binaryFile = binPath
}

================================================
FILE: benchmark/client_server_bench_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "log" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func BenchmarkServer(b *testing.B) { server := NewServer() client := NewClient([]string{server.HostPort()}, WithExternalProcess(), WithNoLibrary(), WithNumClients(10), WithNoDurations(), WithTimeout(10*time.Second), ) assert.NoError(b, client.Warmup(), "Warmup failed") b.ResetTimer() started := time.Now() _, err := client.RawCall(b.N) total := time.Since(started) assert.NoError(b, err, "client.RawCall failed") if n := server.RawCalls(); b.N > n { b.Errorf("Server received %v calls, expected at least %v calls", n, b.N) } log.Printf("Calls: %v Duration: %v RPS: %.0f", b.N, total, float64(b.N)/total.Seconds()) } func BenchmarkClient(b *testing.B) { servers := make([]Server, 3) serverHosts := make([]string, len(servers)) for i := range servers { servers[i] = NewServer( WithExternalProcess(), WithNoLibrary(), ) serverHosts[i] = servers[i].HostPort() } // To saturate a single process, we need to have multiple clients. client := NewClient(serverHosts, WithNoChecking(), WithNumClients(10), ) require.NoError(b, client.Warmup(), "Warmup failed") b.ResetTimer() started := time.Now() if _, err := client.RawCall(b.N); err != nil { b.Fatalf("Call failed: %v", err) } total := time.Since(started) log.Printf("Calls: %v Duration: %v RPS: %.0f", b.N, total, float64(b.N)/total.Seconds()) } ================================================ FILE: benchmark/external_client.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "fmt" "strconv" "strings" "time" ) // externalClient represents a benchmark client running out-of-process. type externalClient struct { *externalCmd opts *options } func newExternalClient(hosts []string, opts *options) Client { benchArgs := []string{ "--service", opts.svcName, "--timeout", opts.timeout.String(), "--request-size", strconv.Itoa(opts.reqSize), "--num-clients", strconv.Itoa(opts.numClients), } if opts.noDurations { benchArgs = append(benchArgs, "--no-durations") } if opts.noLibrary { benchArgs = append(benchArgs, "--no-library") } benchArgs = append(benchArgs, hosts...) 
cmd, initial := newExternalCmd("benchclient/main.go", benchArgs) if !strings.Contains(initial, "started") { panic("bench-client did not start, got: " + initial) } return &externalClient{cmd, opts} } func (c *externalClient) Warmup() error { out, err := c.writeAndRead("warmup") if err != nil { return err } if out != "success" { return fmt.Errorf("warmup failed: %v", out) } return nil } func (c *externalClient) callAndParse(cmd string) ([]time.Duration, error) { out, err := c.writeAndRead(cmd) if err != nil { return nil, err } if out == "" { return nil, nil } durationStrs := strings.Split(out, " ") durations := make([]time.Duration, len(durationStrs)) for i, s := range durationStrs { d, err := time.ParseDuration(s) if err != nil { return nil, fmt.Errorf("calls failed: %v", out) } durations[i] = d } return durations, nil } func (c *externalClient) RawCall(n int) ([]time.Duration, error) { return c.callAndParse(fmt.Sprintf("rcall %v", n)) } func (c *externalClient) ThriftCall(n int) ([]time.Duration, error) { return c.callAndParse(fmt.Sprintf("tcall %v", n)) } ================================================ FILE: benchmark/external_common.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "bufio" "io" "os" "os/exec" ) var _bm = newBuildManager() // externalCmd handles communication with an external benchmark client. type externalCmd struct { cmd *exec.Cmd stdoutOrig io.ReadCloser stdout *bufio.Scanner stdin io.WriteCloser } func newExternalCmd(mainFile string, benchArgs []string) (*externalCmd, string) { bin, err := _bm.GoBinary(BenchmarkDir + mainFile) if err != nil { panic("failed to compile " + mainFile + ": " + err.Error()) } cmd := exec.Command(bin, benchArgs...) cmd.Stderr = os.Stderr stdout, err := cmd.StdoutPipe() if err != nil { panic("failed to create stdout: " + err.Error()) } stdin, err := cmd.StdinPipe() if err != nil { panic("failed to create stdin: " + err.Error()) } if err := cmd.Start(); err != nil { panic("failed to start external process: " + err.Error()) } stdoutScanner := bufio.NewScanner(stdout) if !stdoutScanner.Scan() { panic("failed to check if external process started: " + err.Error()) } out := stdoutScanner.Text() return &externalCmd{ cmd: cmd, stdin: stdin, stdout: stdoutScanner, stdoutOrig: stdout, }, out } func (c *externalCmd) writeAndRead(cmd string) (string, error) { if _, err := io.WriteString(c.stdin, cmd+"\n"); err != nil { return "", err } if c.stdout.Scan() { return c.stdout.Text(), nil } return "", c.stdout.Err() } func (c *externalCmd) Close() { c.stdin.Close() c.stdoutOrig.Close() c.cmd.Process.Kill() } ================================================ FILE: benchmark/external_server.go 
================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "net" "strconv" "strings" ) // externalServer represents a benchmark server running out-of-process. 
type externalServer struct { *externalCmd hostPort string opts *options } func newExternalServer(opts *options) Server { benchArgs := []string{ "--service", opts.svcName, } if len(opts.advertiseHosts) > 0 { benchArgs = append(benchArgs, "--advertise-hosts", strings.Join(opts.advertiseHosts, ",")) } cmd, hostPortStr := newExternalCmd("benchserver/main.go", benchArgs) if _, _, err := net.SplitHostPort(hostPortStr); err != nil { panic("bench-server did not print host:port on startup: " + err.Error()) } return &externalServer{cmd, hostPortStr, opts} } func (s *externalServer) HostPort() string { return s.hostPort } func (s *externalServer) RawCalls() int { return s.writeAndReadInt("count-raw") } func (s *externalServer) ThriftCalls() int { return s.writeAndReadInt("count-thrift") } func (s *externalServer) writeAndReadInt(cmd string) int { v, err := s.writeAndRead(cmd) if err != nil { panic(err) } vInt, err := strconv.Atoi(v) if err != nil { panic(err) } return vInt } ================================================ FILE: benchmark/frame_templates.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "bytes" "encoding/binary" "io" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" ) const ( _idOffset = 4 /* size (2) + type (1) + reserved (1) */ _idOffsetEnd = _idOffset + 4 /* length */ ) type frames struct { outgoing [][]byte incoming [][]byte } func (f frames) duplicate() frames { return frames{ outgoing: deepCopyByteSlice(f.outgoing), incoming: deepCopyByteSlice(f.incoming), } } func deepCopyByteSlice(bs [][]byte) [][]byte { newBs := make([][]byte, len(bs)) for i, b := range bs { newBs[i] = make([]byte, len(b)) copy(newBs[i], b) } return newBs } func (f frames) writeInitReq(w io.Writer) error { _, err := w.Write(f.outgoing[0]) return err } func (f frames) writeInitRes(w io.Writer) error { _, err := w.Write(f.incoming[0]) return err } func (f frames) writeCallReq(id uint32, w io.Writer) (int, error) { frames := f.outgoing[1:] return f.writeMulti(id, w, frames) } func (f frames) writeCallRes(id uint32, w io.Writer) (int, error) { frames := f.incoming[1:] return f.writeMulti(id, w, frames) } func (f frames) writeMulti(id uint32, w io.Writer, frames [][]byte) (int, error) { written := 0 for _, f := range frames { binary.BigEndian.PutUint32(f[_idOffset:_idOffsetEnd], id) if _, err := w.Write(f); err != nil { return written, err } written++ } return written, nil } func getRawCallFrames(timeout time.Duration, svcName string, reqSize int) frames { var fs frames modifier := func(fromClient bool, f *tchannel.Frame) *tchannel.Frame { buf := &bytes.Buffer{} if err := f.WriteOut(buf); err != nil { panic(err) } if fromClient { fs.outgoing = append(fs.outgoing, buf.Bytes()) } else { fs.incoming = append(fs.incoming, 
buf.Bytes()) } return f } withNewServerClient(svcName, func(server, client *tchannel.Channel) { testutils.RegisterEcho(server, nil) relay, err := NewTCPFrameRelay([]string{server.PeerInfo().HostPort}, modifier) if err != nil { panic(err) } defer relay.Close() args := &raw.Args{ Arg2: getRequestBytes(reqSize), Arg3: getRequestBytes(reqSize), } ctx, cancel := tchannel.NewContext(timeout) defer cancel() if _, _, _, err := raw.Call(ctx, client, relay.HostPort(), svcName, "echo", args.Arg2, args.Arg3); err != nil { panic(err) } }) return fs } func withNewServerClient(svcName string, f func(server, client *tchannel.Channel)) { opts := testutils.NewOpts().SetServiceName(svcName) server, err := testutils.NewServerChannel(opts) if err != nil { panic(err) } defer server.Close() client, err := testutils.NewClientChannel(opts) if err != nil { panic(err) } defer client.Close() f(server, client) } ================================================ FILE: benchmark/interfaces.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
// BenchmarkDir should be set to the benchmark source directory. It is used as
// the prefix for the main files of the external client and server.
var BenchmarkDir = "./"

type (
	// Client is a benchmark client that can be used to call a benchmark server.
	Client interface {
		// Warmup will create connections to all host:ports the client was
		// created with.
		Warmup() error

		// RawCall makes n echo calls using raw.
		RawCall(n int) ([]time.Duration, error)

		// ThriftCall makes n echo calls using thrift.
		ThriftCall(n int) ([]time.Duration, error)

		// Close closes the benchmark client.
		Close()
	}

	// inProcClient represents a client that is running in the same process.
	// It adds methods that write into a caller-supplied buffer to reduce
	// allocations.
	inProcClient interface {
		Client

		// RawCallBuffer will make len(latencies) raw calls and store the
		// latencies in the specified buffer.
		RawCallBuffer(latencies []time.Duration) error

		// ThriftCallBuffer will make len(latencies) thrift calls and store
		// the latencies in the specified buffer.
		ThriftCallBuffer(latencies []time.Duration) error
	}

	// Server is a benchmark server that can receive requests.
	Server interface {
		// HostPort returns the HostPort that the server is listening on.
		HostPort() string

		// Close closes the benchmark server.
		Close()

		// RawCalls returns the number of raw calls the server has received.
		RawCalls() int

		// ThriftCalls returns the number of Thrift calls the server has
		// received.
		ThriftCalls() int
	}

	// Relay represents a relay for benchmarking.
	Relay interface {
		// HostPort is the host:port that the relay is listening on.
		HostPort() string

		// Close closes the relay.
		Close()
	}
)
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "bytes" "fmt" "os" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/thrift" gen "github.com/uber/tchannel-go/thrift/gen-go/test" ) // internalClient represents a benchmark client. type internalClient struct { ch *tchannel.Channel sc *tchannel.SubChannel tClient gen.TChanSecondService argStr string argBytes []byte checkResult bool opts *options } // NewClient returns a new Client that can make calls to a benchmark server. 
func NewClient(hosts []string, optFns ...Option) Client { opts := getOptions(optFns) if opts.external { return newExternalClient(hosts, opts) } if opts.numClients > 1 { return newInternalMultiClient(hosts, opts) } return newClient(hosts, opts) } func newClient(hosts []string, opts *options) inProcClient { if opts.external || opts.numClients > 1 { panic("newClient got options that should be handled by NewClient") } if opts.noLibrary { return newInternalTCPClient(hosts, opts) } return newInternalClient(hosts, opts) } func newInternalClient(hosts []string, opts *options) inProcClient { ch, err := tchannel.NewChannel(opts.svcName, &tchannel.ChannelOptions{ Logger: tchannel.NewLevelLogger(tchannel.NewLogger(os.Stderr), tchannel.LogLevelWarn), }) if err != nil { panic("failed to create channel: " + err.Error()) } for _, host := range hosts { ch.Peers().Add(host) } thriftClient := thrift.NewClient(ch, opts.svcName, nil) client := gen.NewTChanSecondServiceClient(thriftClient) return &internalClient{ ch: ch, sc: ch.GetSubChannel(opts.svcName), tClient: client, argBytes: getRequestBytes(opts.reqSize), argStr: getRequestString(opts.reqSize), opts: opts, } } func (c *internalClient) Warmup() error { for _, peer := range c.ch.Peers().Copy() { ctx, cancel := tchannel.NewContext(c.opts.timeout) _, err := peer.GetConnection(ctx) cancel() if err != nil { return err } } return nil } func (c *internalClient) makeCalls(latencies []time.Duration, f func() (time.Duration, error)) error { for i := range latencies { var err error latencies[i], err = f() if err != nil { return err } } return nil } func (c *internalClient) RawCallBuffer(latencies []time.Duration) error { return c.makeCalls(latencies, func() (time.Duration, error) { ctx, cancel := tchannel.NewContext(c.opts.timeout) defer cancel() started := time.Now() rArg2, rArg3, _, err := raw.CallSC(ctx, c.sc, "echo", c.argBytes, c.argBytes) duration := time.Since(started) if err != nil { return 0, err } if c.checkResult { if 
!bytes.Equal(rArg2, c.argBytes) || !bytes.Equal(rArg3, c.argBytes) { fmt.Println("Arg2", rArg2, "Expect", c.argBytes) fmt.Println("Arg3", rArg3, "Expect", c.argBytes) panic("echo call returned wrong results") } } return duration, nil }) } func (c *internalClient) RawCall(n int) ([]time.Duration, error) { latencies := make([]time.Duration, n) return latencies, c.RawCallBuffer(latencies) } func (c *internalClient) ThriftCallBuffer(latencies []time.Duration) error { return c.makeCalls(latencies, func() (time.Duration, error) { ctx, cancel := thrift.NewContext(c.opts.timeout) defer cancel() started := time.Now() res, err := c.tClient.Echo(ctx, c.argStr) duration := time.Since(started) if err != nil { return 0, err } if c.checkResult { if res != c.argStr { panic("thrift Echo returned wrong result") } } return duration, nil }) } func (c *internalClient) ThriftCall(n int) ([]time.Duration, error) { latencies := make([]time.Duration, n) return latencies, c.ThriftCallBuffer(latencies) } func (c *internalClient) Close() { c.ch.Close() } ================================================ FILE: benchmark/internal_multi_client.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "time" "github.com/uber/tchannel-go/testutils" ) type internalMultiClient struct { clients []inProcClient } func newInternalMultiClient(hosts []string, opts *options) Client { clients := make([]inProcClient, opts.numClients) opts.numClients = 1 for i := range clients { clients[i] = newClient(hosts, opts) } return &internalMultiClient{clients: clients} } func (c *internalMultiClient) Warmup() error { for _, c := range c.clients { if err := c.Warmup(); err != nil { return err } } return nil } func (c *internalMultiClient) Close() { for _, client := range c.clients { client.Close() } } func (c *internalMultiClient) RawCall(n int) ([]time.Duration, error) { return c.makeCalls(n, func(c inProcClient) callFunc { return c.RawCallBuffer }) } func (c *internalMultiClient) ThriftCall(n int) ([]time.Duration, error) { return c.makeCalls(n, func(c inProcClient) callFunc { return c.ThriftCallBuffer }) } type callFunc func([]time.Duration) error type clientToCallFunc func(c inProcClient) callFunc func (c *internalMultiClient) makeCalls(n int, f clientToCallFunc) ([]time.Duration, error) { buckets := testutils.Buckets(n, len(c.clients)) errCs := make([]chan error, len(c.clients)) var start int latencies := make([]time.Duration, n) for i := range c.clients { calls := buckets[i] end := start + calls errCs[i] = c.callUsingClient(latencies[start:end], f(c.clients[i])) start = end } for _, errC := range errCs { if err := <-errC; err != nil { return nil, err } } return latencies, nil 
} func (c *internalMultiClient) callUsingClient(latencies []time.Duration, f callFunc) chan error { errC := make(chan error, 1) if len(latencies) == 0 { errC <- nil return errC } go func() { errC <- f(latencies) }() return errC } ================================================ FILE: benchmark/internal_server.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "fmt" "os" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/hyperbahn" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/thrift" gen "github.com/uber/tchannel-go/thrift/gen-go/test" "go.uber.org/atomic" "golang.org/x/net/context" ) // internalServer represents a benchmark server. 
type internalServer struct { ch *tchannel.Channel hc *hyperbahn.Client opts *options rawCalls atomic.Int64 thriftCalls atomic.Int64 } // NewServer returns a new Server that can recieve Thrift calls or raw calls. func NewServer(optFns ...Option) Server { opts := getOptions(optFns) if opts.external { return newExternalServer(opts) } ch, err := tchannel.NewChannel(opts.svcName, &tchannel.ChannelOptions{ Logger: tchannel.NewLevelLogger(tchannel.NewLogger(os.Stderr), tchannel.LogLevelWarn), }) if err != nil { panic("failed to create channel: " + err.Error()) } if err := ch.ListenAndServe("127.0.0.1:0"); err != nil { panic("failed to listen on port 0: " + err.Error()) } s := &internalServer{ ch: ch, opts: opts, } tServer := thrift.NewServer(ch) tServer.Register(gen.NewTChanSecondServiceServer(handler{calls: &s.thriftCalls})) ch.Register(raw.Wrap(rawHandler{calls: &s.rawCalls}), "echo") if len(opts.advertiseHosts) > 0 { if err := s.Advertise(opts.advertiseHosts); err != nil { panic("failed to advertise: " + err.Error()) } } return s } // HostPort returns the host:port that the server is listening on. func (s *internalServer) HostPort() string { return s.ch.PeerInfo().HostPort } // Advertise advertises with Hyperbahn. 
func (s *internalServer) Advertise(hyperbahnHosts []string) error { var err error config := hyperbahn.Configuration{InitialNodes: hyperbahnHosts} s.hc, err = hyperbahn.NewClient(s.ch, config, nil) if err != nil { panic("failed to setup Hyperbahn client: " + err.Error()) } return s.hc.Advertise() } func (s *internalServer) Close() { s.ch.Close() if s.hc != nil { s.hc.Close() } } func (s *internalServer) RawCalls() int { return int(s.rawCalls.Load()) } func (s *internalServer) ThriftCalls() int { return int(s.thriftCalls.Load()) } type rawHandler struct { calls *atomic.Int64 } func (rawHandler) OnError(ctx context.Context, err error) { fmt.Println("benchmark.Server error:", err) } func (h rawHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { h.calls.Inc() return &raw.Res{ Arg2: args.Arg2, Arg3: args.Arg3, }, nil } type handler struct { calls *atomic.Int64 } func (h handler) Echo(ctx thrift.Context, arg1 string) (string, error) { h.calls.Inc() return arg1, nil } ================================================ FILE: benchmark/internal_tcp_client.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "fmt" "math/rand" "net" "time" "github.com/uber/tchannel-go" ) // internalTCPClient represents a TCP client that makes // TChannel calls using raw TCP packets. type internalTCPClient struct { host string lastID uint32 responseIDs chan uint32 conn net.Conn frames frames opts *options } func newInternalTCPClient(hosts []string, opts *options) inProcClient { return &internalTCPClient{ host: hosts[rand.Intn(len(hosts))], responseIDs: make(chan uint32, 1000), frames: getRawCallFrames(opts.timeout, opts.svcName, opts.reqSize), lastID: 1, opts: opts, } } func (c *internalTCPClient) Warmup() error { conn, err := net.Dial("tcp", c.host) if err != nil { return err } c.conn = conn go c.readConn() if err := c.frames.writeInitReq(conn); err != nil { panic(err) } return nil } func (c *internalTCPClient) readConn() { defer close(c.responseIDs) wantFirstID := true f := tchannel.NewFrame(tchannel.MaxFrameSize) for { err := f.ReadIn(c.conn) if err != nil { return } if wantFirstID { if f.Header.ID != 1 { panic(fmt.Errorf("Expected first response ID to be 1, got %v", f.Header.ID)) } wantFirstID = false continue } c.responseIDs <- f.Header.ID } } type call struct { id uint32 started time.Time numFrames int } func (c *internalTCPClient) makeCalls(latencies []time.Duration, f func() (call, error)) error { n := len(latencies) calls := make(map[uint32]*call, n) for i := 0; i < n; i++ { c, err := f() if err != nil { return err } calls[c.id] = &c } timer := time.NewTimer(c.opts.timeout) // Use the original underlying slice for latencies. 
durations := latencies[:0] for { if len(calls) == 0 { return nil } timer.Reset(c.opts.timeout) select { case id, ok := <-c.responseIDs: if !ok { panic("expecting more calls, but connection is closed") } call, ok := calls[id] if !ok { panic(fmt.Errorf("received unexpected response frame: %v", id)) } call.numFrames-- if call.numFrames != 0 { continue } durations = append(durations, time.Since(call.started)) delete(calls, id) case <-timer.C: return tchannel.ErrTimeout } } } func (c *internalTCPClient) RawCallBuffer(latencies []time.Duration) error { return c.makeCalls(latencies, func() (call, error) { c.lastID++ started := time.Now() numFrames, err := c.frames.writeCallReq(c.lastID, c.conn) if err != nil { return call{}, err } return call{c.lastID, started, numFrames}, nil }) } func (c *internalTCPClient) RawCall(n int) ([]time.Duration, error) { latencies := make([]time.Duration, n) return latencies, c.RawCallBuffer(latencies) } func (c *internalTCPClient) ThriftCallBuffer(latencies []time.Duration) error { panic("not yet implemented") } func (c *internalTCPClient) ThriftCall(n int) ([]time.Duration, error) { panic("not yet implemented") } func (c *internalTCPClient) Close() { c.conn.Close() } ================================================ FILE: benchmark/internal_tcp_server.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "log" "net" "github.com/uber/tchannel-go" "go.uber.org/atomic" ) // internalTCPServer represents a TCP server responds to TChannel // calls using raw TCP packets. type internalTCPServer struct { frames frames ln net.Listener opts *options rawCalls atomic.Int64 } func newInternalTCPServer(opts *options) Server { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { panic(err) } s := &internalTCPServer{ ln: ln, frames: getRawCallFrames(opts.timeout, opts.svcName, opts.reqSize), opts: opts, } go s.acceptLoop() return s } func (s *internalTCPServer) acceptLoop() { for { conn, err := s.ln.Accept() if err, ok := err.(net.Error); ok && err.Temporary() { continue } if err != nil { return } go s.handleConn(conn) } } func (s *internalTCPServer) handleConn(conn net.Conn) { c := make(chan uint32, 1000) defer close(c) go s.writeResponses(conn, c) var lastID uint32 f := tchannel.NewFrame(tchannel.MaxFrameSize) for { if err := f.ReadIn(conn); err != nil { return } if f.Header.ID > lastID { c <- f.Header.ID lastID = f.Header.ID } } } func (s *internalTCPServer) writeResponses(conn net.Conn, ids chan uint32) { frames := s.frames.duplicate() for id := range ids { if id == 1 { if err := frames.writeInitRes(conn); err != nil { log.Printf("writeInitRes failed: %v", err) } continue } s.rawCalls.Inc() if _, err := frames.writeCallRes(id, conn); err != nil { log.Printf("writeCallRes failed: %v", err) return } } } func (s *internalTCPServer) HostPort() string { return 
s.ln.Addr().String() } func (s *internalTCPServer) RawCalls() int { return int(s.rawCalls.Load()) } func (s *internalTCPServer) ThriftCalls() int { // Server does not support Thrift calls currently. return 0 } func (s *internalTCPServer) Close() { s.ln.Close() } ================================================ FILE: benchmark/matrix_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "fmt" "testing" "github.com/uber/tchannel-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // combinations will call f with every combination of selecting elements // from a slice with the specified length. // e.g. 
for 2, the callback would be: // f(false, false) // f(false, true) // f(true, false) // f(true, true) func combinations(length int, f func([]bool)) { cur := make([]bool, length) toGenerate := (1 << uint(length)) f(cur) for i := 0; i < toGenerate-1; i++ { var digit int for digit = length - 1; cur[digit]; digit-- { cur[digit] = false } cur[digit] = true f(cur) } } func TestCombinations(t *testing.T) { tests := []struct { length int want [][]bool }{ { length: 1, want: [][]bool{{false}, {true}}, }, { length: 2, want: [][]bool{{false, false}, {false, true}, {true, false}, {true, true}}, }, } for _, tt := range tests { var got [][]bool recordCombs := func(comb []bool) { copied := append([]bool(nil), comb...) got = append(got, copied) } combinations(tt.length, recordCombs) assert.Equal(t, tt.want, got, "Mismatch for combinations of length %v", tt.length) } } func selectOptions(options []Option, toSelect []bool) []Option { var opts []Option for i, v := range toSelect { if v { opts = append(opts, options[i]) } } return opts } func combineOpts(base, override []Option) []Option { resultOpts := append([]Option(nil), base...) return append(resultOpts, override...) } func runSingleTest(t *testing.T, baseOpts, serverOpts, clientOpts []Option) { serverOpts = combineOpts(baseOpts, serverOpts) clientOpts = combineOpts(baseOpts, clientOpts) msgP := fmt.Sprintf("%+v: ", struct { serverOpts options clientOpts options }{*(getOptions(serverOpts)), *(getOptions(clientOpts))}) server := NewServer(serverOpts...) defer server.Close() client := NewClient([]string{server.HostPort()}, clientOpts...) 
defer client.Close() require.NoError(t, client.Warmup(), msgP+"Client warmup failed") durations, err := client.RawCall(0) require.NoError(t, err, msgP+"Call(0) failed") assert.Equal(t, 0, len(durations), msgP+"Wrong number of calls") assert.Equal(t, 0, server.RawCalls(), msgP+"server.RawCalls mismatch") assert.Equal(t, 0, server.ThriftCalls(), msgP+"server.ThriftCalls mismatch") expectCalls := 0 for i := 1; i < 10; i *= 2 { durations, err = client.RawCall(i) require.NoError(t, err, msgP+"Call(%v) failed", i) require.Equal(t, i, len(durations), msgP+"Wrong number of calls") expectCalls += i require.Equal(t, expectCalls, server.RawCalls(), msgP+"server.RawCalls mismatch") require.Equal(t, 0, server.ThriftCalls(), msgP+"server.ThriftCalls mismatch") } } func TestServerClientMatrix(t *testing.T) { tests := [][]Option{ {WithServiceName("other")}, {WithRequestSize(tchannel.MaxFrameSize)}, } // These options can be independently applied to the server or the client. independentOpts := []Option{ WithExternalProcess(), WithNoLibrary(), } // These options only apply to the client. clientOnlyOpts := combineOpts(independentOpts, []Option{ WithNumClients(5), }) for _, tt := range tests { combinations(len(independentOpts), func(serverSelect []bool) { combinations(len(clientOnlyOpts), func(clientSelect []bool) { serverOpts := selectOptions(independentOpts, serverSelect) clientOpts := selectOptions(clientOnlyOpts, clientSelect) runSingleTest(t, tt, serverOpts, clientOpts) }) }) } } ================================================ FILE: benchmark/options.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// options holds the resolved benchmark configuration built from Option values.
type options struct {
	external  bool
	svcName   string
	noLibrary bool

	// Following options only make sense for clients.
	noChecking bool
	timeout    time.Duration
	reqSize    int
	numClients int

	// noDurations disables printing of durations to stdout.
	// This only applies to clients running out-of-process.
	noDurations bool

	// Following options only make sense for servers.
	advertiseHosts []string
}

// Option represents a Benchmark option.
type Option func(*options)

// WithTimeout sets the timeout to use for each call.
func WithTimeout(timeout time.Duration) Option {
	return func(o *options) { o.timeout = timeout }
}

// WithRequestSize sets the request size for each call.
func WithRequestSize(reqSize int) Option {
	return func(o *options) { o.reqSize = reqSize }
}

// WithServiceName sets the service name of the benchmark server.
func WithServiceName(svcName string) Option {
	return func(o *options) { o.svcName = svcName }
}

// WithExternalProcess creates a separate process to host the server/client.
func WithExternalProcess() Option {
	return func(o *options) { o.external = true }
}

// WithNoLibrary uses the fast TCP-template based approach for generating
// TChannel frames rather than the TChannel client library.
func WithNoLibrary() Option {
	return func(o *options) { o.noLibrary = true }
}

// WithNoChecking disables result verification on the client side.
// Verification may slow down the client, as it compares all request bytes
// against the response bytes.
func WithNoChecking() Option {
	return func(o *options) { o.noChecking = true }
}

// WithNumClients sets the number of concurrent TChannel clients to use
// internally under a single benchmark.Client. This is used to generate
// a large amount of traffic, as a single TChannel client will not saturate
// a CPU since it will spend most of the time blocking and waiting for the
// remote side to respond.
func WithNumClients(numClients int) Option {
	return func(o *options) { o.numClients = numClients }
}

// WithNoDurations disables printing of latencies to standard out.
func WithNoDurations() Option {
	return func(o *options) { o.noDurations = true }
}

// WithAdvertiseHosts sets the hosts to advertise with on startup.
func WithAdvertiseHosts(hosts []string) Option {
	return func(o *options) { o.advertiseHosts = hosts }
}

// getOptions applies optFns, in order, on top of the defaults (a one second
// timeout and the service name "bench-server") and returns the result.
func getOptions(optFns []Option) *options {
	resolved := new(options)
	resolved.timeout = time.Second
	resolved.svcName = "bench-server"

	for i := range optFns {
		optFns[i](resolved)
	}
	return resolved
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "errors" "os" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/relay" "github.com/uber/tchannel-go/relay/relaytest" "go.uber.org/atomic" ) type fixedHosts struct { hosts map[string][]string appends []relay.KeyVal pickI atomic.Int32 } func (fh *fixedHosts) Get(cf relay.CallFrame, _ *relay.Conn) (string, error) { peers := fh.hosts[string(cf.Service())] if len(peers) == 0 { return "", errors.New("no peers") } for _, kv := range fh.appends { cf.Arg2Append(kv.Key, kv.Val) } pickI := int(fh.pickI.Inc()-1) % len(peers) return peers[pickI], nil } type realRelay struct { ch *tchannel.Channel hosts *fixedHosts } // NewRealRelay creates a TChannel relay. 
func NewRealRelay(services map[string][]string, appends []relay.KeyVal) (Relay, error) { hosts := &fixedHosts{ hosts: services, appends: appends, } ch, err := tchannel.NewChannel("relay", &tchannel.ChannelOptions{ RelayHost: relaytest.HostFunc(hosts.Get), Logger: tchannel.NewLevelLogger(tchannel.NewLogger(os.Stderr), tchannel.LogLevelWarn), }) if err != nil { return nil, err } if err := ch.ListenAndServe("127.0.0.1:0"); err != nil { return nil, err } return &realRelay{ ch: ch, hosts: hosts, }, nil } func (r *realRelay) HostPort() string { return r.ch.PeerInfo().HostPort } func (r *realRelay) Close() { r.ch.Close() } ================================================ FILE: benchmark/req_bytes.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
// getRequestBytes returns n bytes where the byte at index i is byte(i), so
// the values cycle through 0..255.
func getRequestBytes(n int) []byte {
	out := make([]byte, n)
	for i := 0; i < n; i++ {
		out[i] = byte(i)
	}
	return out
}

// getRequestString returns a string of length n that repeats a fixed
// alphabet of ASCII letters.
func getRequestString(n int) string {
	// TODO: we should replace this with base64 once we drop go1.4 support.
	// NOTE(review): the lowercase run reads "abcede...", i.e. 'e' appears
	// twice. Kept as-is because it determines the generated payload.
	const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcedefghijklmnopqrstuvwxyz"

	out := make([]byte, n)
	for i := 0; i < n; i++ {
		out[i] = alphabet[i%len(alphabet)]
	}
	return string(out)
}
package benchmark import ( "io" "io/ioutil" "net" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func echoServer(tb testing.TB) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(tb, err, "Listen failed") go func() { conn, err := ln.Accept() require.NoError(tb, err, "Accept failed") // Echo the connection back to itself. io.Copy(conn, conn) }() return ln } func benchmarkClient(b *testing.B, dst string, reqSize int) { req := getRequestBytes(reqSize) totalExpected := b.N * reqSize conn, err := net.Dial("tcp", dst) require.NoError(b, err, "Failed to connect to destination") defer conn.Close() readerDone := make(chan struct{}) go func() { defer close(readerDone) n, err := io.CopyN(ioutil.Discard, conn, int64(totalExpected)) assert.NoError(b, err, "Expected %v response bytes, got %v", totalExpected, n) }() b.SetBytes(int64(reqSize)) for i := 0; i < b.N; i++ { _, err := conn.Write(req) require.NoError(b, err, "Write failed") } <-readerDone } func benchmarkTCPDirect(b *testing.B, reqSize int) { ln := echoServer(b) benchmarkClient(b, ln.Addr().String(), reqSize) } func BenchmarkTCPDirect100Bytes(b *testing.B) { benchmarkTCPDirect(b, 100) } func BenchmarkTCPDirect1k(b *testing.B) { benchmarkTCPDirect(b, 1024) } func BenchmarkTCPDirect4k(b *testing.B) { benchmarkTCPDirect(b, 4*1024) } func benchmarkTCPRelay(b *testing.B, reqSize int) { ln := echoServer(b) relay, err := NewTCPRawRelay([]string{ln.Addr().String()}) require.NoError(b, err, "Relay failed") defer relay.Close() benchmarkClient(b, relay.HostPort(), reqSize) } func BenchmarkTCPRelay100Bytes(b *testing.B) { benchmarkTCPRelay(b, 100) } func BenchmarkTCPRelay1kBytes(b *testing.B) { benchmarkTCPRelay(b, 1024) } func BenchmarkTCPRelay4k(b *testing.B) { benchmarkTCPRelay(b, 4*1024) } ================================================ FILE: benchmark/tcp_frame_relay.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "log" "net" ) import "github.com/uber/tchannel-go" type tcpFrameRelay struct { *tcpRelay modifier func(bool, *tchannel.Frame) *tchannel.Frame } // NewTCPFrameRelay relays frames from one connection to another. It reads // and writes frames using the TChannel frame functions. 
func NewTCPFrameRelay(dests []string, modifier func(bool, *tchannel.Frame) *tchannel.Frame) (Relay, error) { var err error r := &tcpFrameRelay{modifier: modifier} r.tcpRelay, err = newTCPRelay(dests, r.handleConnFrameRelay) if err != nil { return nil, err } return r, nil } func (r *tcpFrameRelay) handleConnFrameRelay(fromClient bool, src, dst net.Conn) { pool := tchannel.NewSyncFramePool() frameCh := make(chan *tchannel.Frame, 100) defer close(frameCh) go func() { for f := range frameCh { if err := f.WriteOut(dst); err != nil { log.Printf("Failed to write out frame: %v", err) return } pool.Release(f) } }() for { f := pool.Get() if err := f.ReadIn(src); err != nil { return } if r.modifier != nil { f = r.modifier(fromClient, f) } select { case frameCh <- f: default: panic("frame buffer full") } } } ================================================ FILE: benchmark/tcp_raw_relay.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package benchmark import ( "io" "log" "net" "go.uber.org/atomic" ) type tcpRelay struct { destI atomic.Int32 dests []string ln net.Listener handleConn func(fromClient bool, src, dst net.Conn) } func newTCPRelay(dests []string, handleConn func(fromClient bool, src, dst net.Conn)) (*tcpRelay, error) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, err } relay := &tcpRelay{ dests: dests, ln: ln, handleConn: handleConn, } go relay.acceptLoop() return relay, nil } // NewTCPRawRelay creates a relay that just pipes data from one connection // to another directly. func NewTCPRawRelay(dests []string) (Relay, error) { return newTCPRelay(dests, func(_ bool, src, dst net.Conn) { io.Copy(src, dst) }) } func (r *tcpRelay) acceptLoop() { for { conn, err := r.ln.Accept() if err, ok := err.(net.Error); ok && err.Temporary() { continue } if err != nil { return } go r.handleIncoming(conn) } } func (r *tcpRelay) handleIncoming(src net.Conn) { defer src.Close() dst, err := net.Dial("tcp", r.nextDestination()) if err != nil { log.Printf("Connection failed: %v", err) return } defer dst.Close() go r.handleConn(true, src, dst) r.handleConn(false, dst, src) } func (r *tcpRelay) nextDestination() string { i := int(r.destI.Inc()-1) % len(r.dests) return r.dests[i] } func (r *tcpRelay) HostPort() string { return r.ln.Addr().String() } func (r *tcpRelay) Close() { r.ln.Close() } ================================================ FILE: calloptions.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel // Format is the arg scheme used for a specific call. type Format string // The list of formats supported by tchannel. const ( HTTP Format = "http" JSON Format = "json" Raw Format = "raw" Thrift Format = "thrift" ) func (f Format) String() string { return string(f) } // CallOptions are options for a specific call. type CallOptions struct { // Format is arg scheme used for this call, sent in the "as" header. // This header is only set if the Format is set. Format Format // ShardKey determines where this call request belongs, used with ringpop applications. ShardKey string // RequestState stores request state across retry attempts. RequestState *RequestState // RoutingKey identifies the destined traffic group. Relays may favor the // routing key over the service name to route the request to a specialized // traffic group. 
RoutingKey string // RoutingDelegate identifies a traffic group capable of routing a request // to an instance of the intended service. RoutingDelegate string // CallerName defaults to the channel's service name for an outbound call. // Optionally override this field to support transparent proxying when inbound // caller names vary across calls. CallerName string } var defaultCallOptions = &CallOptions{} func (c *CallOptions) setHeaders(headers transportHeaders) { headers[ArgScheme] = Raw.String() c.overrideHeaders(headers) } // overrideHeaders sets headers if the call options contains non-default values. func (c *CallOptions) overrideHeaders(headers transportHeaders) { if c.Format != "" { headers[ArgScheme] = c.Format.String() } if c.ShardKey != "" { headers[ShardKey] = c.ShardKey } if c.RoutingKey != "" { headers[RoutingKey] = c.RoutingKey } if c.RoutingDelegate != "" { headers[RoutingDelegate] = c.RoutingDelegate } if c.CallerName != "" { headers[CallerName] = c.CallerName } } // setResponseHeaders copies some headers from the incoming call request to the response. func setResponseHeaders(reqHeaders, respHeaders transportHeaders) { respHeaders[ArgScheme] = reqHeaders[ArgScheme] } ================================================ FILE: calloptions_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "testing" "github.com/stretchr/testify/assert" ) func TestSetHeaders(t *testing.T) { tests := []struct { format Format routingDelegate string routingKey string callerName string expectedHeaders transportHeaders }{ { // When no format is specified, Raw should be used by default. format: "", expectedHeaders: transportHeaders{ArgScheme: Raw.String()}, }, { format: Thrift, expectedHeaders: transportHeaders{ArgScheme: Thrift.String()}, }, { callerName: "foo-caller", expectedHeaders: transportHeaders{ ArgScheme: Raw.String(), CallerName: "foo-caller", }, }, { format: JSON, routingDelegate: "xpr", expectedHeaders: transportHeaders{ ArgScheme: JSON.String(), RoutingDelegate: "xpr", }, }, { format: JSON, routingKey: "canary", expectedHeaders: transportHeaders{ ArgScheme: JSON.String(), RoutingKey: "canary", }, }, } for _, tt := range tests { callOpts := &CallOptions{ Format: tt.format, RoutingDelegate: tt.routingDelegate, RoutingKey: tt.routingKey, CallerName: tt.callerName, } headers := make(transportHeaders) callOpts.setHeaders(headers) assert.Equal(t, tt.expectedHeaders, headers) } } ================================================ FILE: channel.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"errors"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/uber/tchannel-go/tnet"

	"github.com/opentracing/opentracing-go"
	"go.uber.org/atomic"
	"golang.org/x/net/context"
)

var (
	// errAlreadyListening is returned by Serve and ListenAndServe when the
	// channel already has a listener.
	errAlreadyListening = errors.New("channel already listening")
	// errInvalidStateForOp is returned when an operation is attempted while
	// the channel is in a state that does not allow it.
	errInvalidStateForOp = errors.New("channel is in an invalid state for that method")
	// errMaxIdleTimeNotSet is returned by NewChannel when IdleCheckInterval
	// is positive but MaxIdleTime is not.
	errMaxIdleTimeNotSet = errors.New("IdleCheckInterval is set but MaxIdleTime is zero")

	// ErrNoServiceName is returned when no service name is provided when
	// creating a new channel.
	ErrNoServiceName = errors.New("no service name provided")
)

// ephemeralHostPort is the placeholder host:port reported in the local peer
// info until the channel starts listening.
const ephemeralHostPort = "0.0.0.0:0"

// ChannelOptions are used to control parameters when creating a Channel.
type ChannelOptions struct {
	// Default Connection options
	DefaultConnectionOptions ConnectionOptions

	// The name of the process, for logging and reporting to peers
	ProcessName string

	// OnPeerStatusChanged is an optional callback that receives a notification
	// whenever the channel establishes a usable connection to a peer, or loses
	// a connection to a peer.
	OnPeerStatusChanged func(*Peer)

	// The logger to use for this channel
	Logger Logger

	// The host:port selection implementation to use for relaying. This is an
	// unstable API - breaking changes are likely.
	RelayHost RelayHost

	// The list of service names that should be handled locally by this channel.
	// This is an unstable API - breaking changes are likely.
	RelayLocalHandlers []string

	// The maximum allowable timeout for relayed calls (longer timeouts are
	// clamped to this value). Passing zero uses the default of 2m.
	// This is an unstable API - breaking changes are likely.
	RelayMaxTimeout time.Duration

	// If the relay needs to connect while processing a frame, this specifies
	// the max connection timeout used.
	RelayMaxConnectionTimeout time.Duration

	// RelayMaxTombs is the maximum number of timed-out calls that the relay
	// will keep track of per-connection to avoid spurious logs
	// for late-arriving frames.
	// This is an unstable API - breaking changes are likely.
	RelayMaxTombs uint64

	// RelayTimerVerification will disable pooling of relay timers, and instead
	// verify that timers are not used once they are released.
	// This is an unstable API - breaking changes are likely.
	RelayTimerVerification bool

	// The reporter to use for reporting stats for this channel.
	StatsReporter StatsReporter

	// TimeNow is a variable for overriding time.Now in unit tests.
	// Note: This is not a stable part of the API and may change.
	TimeNow func() time.Time

	// TimeTicker is a variable for overriding time.Ticker in unit tests.
	// Note: This is not a stable part of the API and may change.
	TimeTicker func(d time.Duration) *time.Ticker

	// MaxIdleTime controls how long we allow an idle connection to exist
	// before tearing it down. Must be set to non-zero if IdleCheckInterval
	// is set.
	MaxIdleTime time.Duration

	// IdleCheckInterval controls how often the channel runs a sweep over
	// all active connections to see if they can be dropped. Connections that
	// are idle for longer than MaxIdleTime are disconnected. If this is set to
	// zero (the default), idle checking is disabled.
	IdleCheckInterval time.Duration

	// Tracer is an OpenTracing Tracer used to manage distributed tracing spans.
	// If not set, opentracing.GlobalTracer() is used.
	Tracer opentracing.Tracer

	// Handler is an alternate handler for all inbound requests, overriding the
	// default handler that delegates to a subchannel.
	Handler Handler

	// SkipHandlerMethods allow users to configure TChannel server such that
	// requests with specified methods can be ignored by the above passed-in handler
	// and handled natively by TChannel.
	// Requests with other methods will be handled by passed-in handler.
	// Methods should be in the format of Service::Method.
	// This is useful for the gradual migration purpose.
	SkipHandlerMethods []string

	// Dialer is optional factory method which can be used for overriding
	// outbound connections for things like TLS handshake
	Dialer func(ctx context.Context, network, hostPort string) (net.Conn, error)

	// ConnContext runs when a connection is established, which updates
	// the per-connection base context. This context is used as the parent context
	// for incoming calls.
	ConnContext func(ctx context.Context, conn net.Conn) context.Context
}

// ChannelState is the state of a channel.
type ChannelState int

const (
	// ChannelClient is a channel that can be used as a client.
	ChannelClient ChannelState = iota + 1

	// ChannelListening is a channel that is listening for new connections.
	ChannelListening

	// ChannelStartClose is a channel that has received a Close request.
	// The channel is no longer listening, and all new incoming connections are rejected.
	ChannelStartClose

	// ChannelInboundClosed is a channel that has drained all incoming connections, but may
	// have outgoing connections. All incoming calls and new outgoing calls are rejected.
	ChannelInboundClosed

	// ChannelClosed is a channel that has closed completely.
	ChannelClosed
)

//go:generate stringer -type=ChannelState

// A Channel is a bi-directional connection to the peering and routing network.
// Applications can use a Channel to make service calls to remote peers via
// BeginCall, or to listen for incoming calls from peers. Applications that
// want to receive requests should call one of Serve or ListenAndServe
// TODO(prashant): Shutdown all subchannels + peers when channel is closed.
type Channel struct {
	channelConnectionCommon

	// chID is the ID allocated from _nextChID when the channel is created;
	// it is attached to the channel's logger as the "chID" field.
	chID         uint32
	createdStack string
	// commonStatsTags holds the tags built by createCommonStats; StatsTags
	// returns a copy of it.
	commonStatsTags     map[string]string
	connectionOptions   ConnectionOptions
	peers               *PeerList
	relayHost           RelayHost
	relayMaxTimeout     time.Duration
	relayMaxConnTimeout time.Duration
	relayMaxTombs       uint64
	relayTimerVerify    bool
	internalHandlers    *handlerMap
	// handler is the root handler for inbound requests, chosen in NewChannel
	// from ChannelOptions.Handler and ChannelOptions.SkipHandlerMethods.
	handler             Handler
	onPeerStatusChanged func(*Peer)
	// dialer opens outbound connections; NewChannel defaults it to
	// dialContext unless ChannelOptions.Dialer is set.
	dialer      func(ctx context.Context, hostPort string) (net.Conn, error)
	connContext func(ctx context.Context, conn net.Conn) context.Context
	// closed is closed by onClosed once the channel has completely closed.
	closed chan struct{}

	// mutable contains all the members of Channel which are mutable.
	mutable struct {
		sync.RWMutex // protects members of the mutable struct.
		state        ChannelState
		peerInfo     LocalPeerInfo // May be ephemeral if this is a client only channel
		l            net.Listener  // May be nil if this is a client only channel
		idleSweep    *idleSweep
		// conns tracks the channel's connections, keyed by connection ID.
		conns map[uint32]*Connection
	}
}

// channelConnectionCommon is the list of common objects that both use
// and can be copied directly from the channel to the connection.
type channelConnectionCommon struct {
	log           Logger
	relayLocal    map[string]struct{}
	statsReporter StatsReporter
	tracer        opentracing.Tracer
	subChannels   *subChannelMap
	timeNow       func() time.Time
	timeTicker    func(time.Duration) *time.Ticker
}

// _nextChID is used to allocate unique IDs to every channel for debugging purposes.
var _nextChID atomic.Uint32

// Tracer returns the OpenTracing Tracer for this channel. If no tracer was provided
// in the configuration, returns opentracing.GlobalTracer(). Note that this approach
// allows opentracing.GlobalTracer() to be initialized _after_ the channel is created.
func (ccc channelConnectionCommon) Tracer() opentracing.Tracer {
	if ccc.tracer != nil {
		return ccc.tracer
	}
	return opentracing.GlobalTracer()
}

// NewChannel creates a new Channel. The new channel can be used to send outbound requests
// to peers, but will not listen or handle incoming requests until one of ListenAndServe
// or Serve is called. The local service name should be passed to serviceName.
func NewChannel(serviceName string, opts *ChannelOptions) (*Channel, error) { if serviceName == "" { return nil, ErrNoServiceName } if opts == nil { opts = &ChannelOptions{} } processName := opts.ProcessName if processName == "" { processName = fmt.Sprintf("%s[%d]", filepath.Base(os.Args[0]), os.Getpid()) } logger := opts.Logger if logger == nil { logger = NullLogger } statsReporter := opts.StatsReporter if statsReporter == nil { statsReporter = NullStatsReporter } timeNow := opts.TimeNow if timeNow == nil { timeNow = time.Now } timeTicker := opts.TimeTicker if timeTicker == nil { timeTicker = time.NewTicker } chID := _nextChID.Inc() logger = logger.WithFields( LogField{"serviceName", serviceName}, LogField{"process", processName}, LogField{"chID", chID}, ) if err := opts.validateIdleCheck(); err != nil { return nil, err } // Default to dialContext if dialer is not passed in as an option dialCtx := dialContext if opts.Dialer != nil { dialCtx = func(ctx context.Context, hostPort string) (net.Conn, error) { return opts.Dialer(ctx, "tcp", hostPort) } } if opts.ConnContext == nil { opts.ConnContext = func(ctx context.Context, conn net.Conn) context.Context { return ctx } } ch := &Channel{ channelConnectionCommon: channelConnectionCommon{ log: logger, relayLocal: toStringSet(opts.RelayLocalHandlers), statsReporter: statsReporter, subChannels: &subChannelMap{}, timeNow: timeNow, timeTicker: timeTicker, tracer: opts.Tracer, }, chID: chID, connectionOptions: opts.DefaultConnectionOptions.withDefaults(), relayHost: opts.RelayHost, relayMaxTimeout: validateRelayMaxTimeout(opts.RelayMaxTimeout, logger), relayMaxConnTimeout: opts.RelayMaxConnectionTimeout, relayMaxTombs: opts.RelayMaxTombs, relayTimerVerify: opts.RelayTimerVerification, dialer: dialCtx, connContext: opts.ConnContext, closed: make(chan struct{}), } ch.peers = newRootPeerList(ch, opts.OnPeerStatusChanged).newChild() switch { case len(opts.SkipHandlerMethods) > 0 && opts.Handler != nil: sm, err := 
toServiceMethodSet(opts.SkipHandlerMethods) if err != nil { return nil, err } ch.handler = userHandlerWithSkip{ localHandler: channelHandler{ch}, ignoreUserHandler: sm, userHandler: opts.Handler, } case opts.Handler != nil: ch.handler = opts.Handler default: ch.handler = channelHandler{ch} } ch.mutable.peerInfo = LocalPeerInfo{ PeerInfo: PeerInfo{ ProcessName: processName, HostPort: ephemeralHostPort, IsEphemeral: true, Version: PeerVersion{ Language: "go", LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), TChannelVersion: VersionInfo, }, }, ServiceName: serviceName, } ch.mutable.state = ChannelClient ch.mutable.conns = make(map[uint32]*Connection) ch.createCommonStats() ch.internalHandlers = ch.createInternalHandlers() registerNewChannel(ch) if opts.RelayHost != nil { opts.RelayHost.SetChannel(ch) } // Start the idle connection timer. ch.mutable.idleSweep = startIdleSweep(ch, opts) return ch, nil } // ConnectionOptions returns the channel's connection options. func (ch *Channel) ConnectionOptions() *ConnectionOptions { return &ch.connectionOptions } // Serve serves incoming requests using the provided listener. // The local peer info is set synchronously, but the actual socket listening is done in // a separate goroutine. func (ch *Channel) Serve(l net.Listener) error { mutable := &ch.mutable mutable.Lock() defer mutable.Unlock() if mutable.l != nil { return errAlreadyListening } mutable.l = tnet.Wrap(l) if mutable.state != ChannelClient { return errInvalidStateForOp } mutable.state = ChannelListening mutable.peerInfo.HostPort = l.Addr().String() mutable.peerInfo.IsEphemeral = false ch.log = ch.log.WithFields(LogField{"hostPort", mutable.peerInfo.HostPort}) ch.log.Info("Channel is listening.") go ch.serve() return nil } // ListenAndServe listens on the given address and serves incoming requests. 
// The port may be 0, in which case the channel will use an OS assigned port // This method does not block as the handling of connections is done in a goroutine. func (ch *Channel) ListenAndServe(hostPort string) error { mutable := &ch.mutable mutable.RLock() if mutable.l != nil { mutable.RUnlock() return errAlreadyListening } l, err := net.Listen("tcp", hostPort) if err != nil { mutable.RUnlock() return err } mutable.RUnlock() return ch.Serve(l) } // Registrar is the base interface for registering handlers on either the base // Channel or the SubChannel type Registrar interface { // ServiceName returns the service name that this Registrar is for. ServiceName() string // Register registers a handler for ServiceName and the given method. Register(h Handler, methodName string) // Logger returns the logger for this Registrar. Logger() Logger // StatsReporter returns the stats reporter for this Registrar StatsReporter() StatsReporter // StatsTags returns the tags that should be used. StatsTags() map[string]string // Peers returns the peer list for this Registrar. Peers() *PeerList } // Register registers a handler for a method. // // The handler is registered with the service name used when the Channel was // created. To register a handler with a different service name, obtain a // SubChannel for that service with GetSubChannel, and Register a handler // under that. You may also use SetHandler on a SubChannel to set up a // catch-all Handler for that service. See the docs for SetHandler for more // information. // // Register panics if the channel was constructed with an alternate root // handler that does not support Register. 
func (ch *Channel) Register(h Handler, methodName string) { r, ok := ch.handler.(registrar) if !ok { panic("can't register handler when channel configured with alternate root handler without Register method") } r.Register(h, methodName) } // PeerInfo returns the current peer info for the channel func (ch *Channel) PeerInfo() LocalPeerInfo { ch.mutable.RLock() peerInfo := ch.mutable.peerInfo ch.mutable.RUnlock() return peerInfo } func (ch *Channel) createCommonStats() { ch.commonStatsTags = map[string]string{ "app": ch.mutable.peerInfo.ProcessName, "service": ch.mutable.peerInfo.ServiceName, } host, err := os.Hostname() if err != nil { ch.log.WithFields(ErrField(err)).Info("Channel creation failed to get host.") return } ch.commonStatsTags["host"] = host // TODO(prashant): Allow user to pass extra tags (such as cluster, version). } // GetSubChannel returns a SubChannel for the given service name. If the subchannel does not // exist, it is created. func (ch *Channel) GetSubChannel(serviceName string, opts ...SubChannelOption) *SubChannel { sub, added := ch.subChannels.getOrAdd(serviceName, ch) if added { for _, opt := range opts { opt(sub) } } return sub } // Peers returns the PeerList for the channel. func (ch *Channel) Peers() *PeerList { return ch.peers } // RootPeers returns the root PeerList for the channel, which is the sole place // new Peers are created. All children of the root list (including ch.Peers()) // automatically re-use peers from the root list and create new peers in the // root list. func (ch *Channel) RootPeers() *RootPeerList { return ch.peers.parent } // BeginCall starts a new call to a remote peer, returning an OutboundCall that can // be used to write the arguments of the call. 
func (ch *Channel) BeginCall(ctx context.Context, hostPort, serviceName, methodName string, callOptions *CallOptions) (*OutboundCall, error) { p := ch.RootPeers().GetOrAdd(hostPort) return p.BeginCall(ctx, serviceName, methodName, callOptions) } // serve runs the listener to accept and manage new incoming connections, blocking // until the channel is closed. func (ch *Channel) serve() { acceptBackoff := 0 * time.Millisecond for { netConn, err := ch.mutable.l.Accept() if err != nil { // Backoff from new accepts if this is a temporary error if ne, ok := err.(net.Error); ok && ne.Temporary() { if acceptBackoff == 0 { acceptBackoff = 5 * time.Millisecond } else { acceptBackoff *= 2 } if max := 1 * time.Second; acceptBackoff > max { acceptBackoff = max } ch.log.WithFields( ErrField(err), LogField{"backoff", acceptBackoff}, ).Warn("Accept error, will wait and retry.") time.Sleep(acceptBackoff) continue } else { // Only log an error if this didn't happen due to a Close. if ch.State() >= ChannelStartClose { return } ch.log.WithFields(ErrField(err)).Fatal("Unrecoverable accept error, closing server.") return } } acceptBackoff = 0 // Perform the connection handshake in a background goroutine. go func() { // Register the connection in the peer once the channel is set up. events := connectionEvents{ OnActive: ch.inboundConnectionActive, OnCloseStateChange: ch.connectionCloseStateChange, OnExchangeUpdated: ch.exchangeUpdated, } if _, err := ch.inboundHandshake(context.Background(), netConn, events); err != nil { netConn.Close() } }() } } // Ping sends a ping message to the given hostPort and waits for a response. func (ch *Channel) Ping(ctx context.Context, hostPort string) error { peer := ch.RootPeers().GetOrAdd(hostPort) conn, err := peer.GetConnection(ctx) if err != nil { return err } return conn.ping(ctx) } // Logger returns the logger for this channel. func (ch *Channel) Logger() Logger { return ch.log } // StatsReporter returns the stats reporter for this channel. 
func (ch *Channel) StatsReporter() StatsReporter { return ch.statsReporter } // StatsTags returns the common tags that should be used when reporting stats. // It returns a new map for each call. func (ch *Channel) StatsTags() map[string]string { m := make(map[string]string) for k, v := range ch.commonStatsTags { m[k] = v } return m } // ServiceName returns the serviceName that this channel was created for. func (ch *Channel) ServiceName() string { return ch.PeerInfo().ServiceName } // Connect creates a new outbound connection to hostPort. func (ch *Channel) Connect(ctx context.Context, hostPort string) (*Connection, error) { switch state := ch.State(); state { case ChannelClient, ChannelListening: break default: ch.log.Debugf("Connect rejecting new connection as state is %v", state) return nil, errInvalidStateForOp } // The context timeout applies to the whole call, but users may want a lower // connect timeout (e.g. for streams). if params := getTChannelParams(ctx); params != nil && params.connectTimeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, params.connectTimeout) defer cancel() } events := connectionEvents{ OnActive: ch.outboundConnectionActive, OnCloseStateChange: ch.connectionCloseStateChange, OnExchangeUpdated: ch.exchangeUpdated, } if err := ctx.Err(); err != nil { return nil, GetContextError(err) } timeout := getTimeout(ctx) tcpConn, err := ch.dialer(ctx, hostPort) if err != nil { if ne, ok := err.(net.Error); ok && ne.Timeout() { ch.log.WithFields( LogField{"remoteHostPort", hostPort}, LogField{"timeout", timeout}, ).Info("Outbound net.Dial timed out.") err = ErrTimeout } else if ctx.Err() == context.Canceled { ch.log.WithFields( LogField{"remoteHostPort", hostPort}, ).Info("Outbound net.Dial was cancelled.") err = GetContextError(ErrRequestCancelled) } else { ch.log.WithFields( ErrField(err), LogField{"remoteHostPort", hostPort}, ).Info("Outbound net.Dial failed.") } return nil, err } conn, err := 
ch.outboundHandshake(ctx, tcpConn, hostPort, events) if conn != nil { // It's possible that the connection we just created responds with a host:port // that is not what we tried to connect to. E.g., we may have connected to // 127.0.0.1:1234, but the returned host:port may be 10.0.0.1:1234. // In this case, the connection won't be added to 127.0.0.1:1234 peer // and so future calls to that peer may end up creating new connections. To // avoid this issue, and to avoid clients being aware of any TCP relays, we // add the connection to the intended peer. if hostPort != conn.remotePeerInfo.HostPort { conn.log.Debugf("Outbound connection host:port mismatch, adding to peer %v", conn.remotePeerInfo.HostPort) ch.addConnectionToPeer(hostPort, conn, outbound) } } return conn, err } // exchangeUpdated updates the peer heap. func (ch *Channel) exchangeUpdated(c *Connection) { if c.remotePeerInfo.HostPort == "" { // Hostport is unknown until we get init resp. return } p, ok := ch.RootPeers().Get(c.remotePeerInfo.HostPort) if !ok { return } ch.updatePeer(p) } // updatePeer updates the score of the peer and update it's position in heap as well. func (ch *Channel) updatePeer(p *Peer) { ch.peers.onPeerChange(p) ch.subChannels.updatePeer(p) p.callOnUpdateComplete() } // addConnection adds the connection to the channel's list of connection // if the channel is in a valid state to accept this connection. It returns // whether the connection was added. 
func (ch *Channel) addConnection(c *Connection, direction connectionDirection) bool { ch.mutable.Lock() defer ch.mutable.Unlock() if c.readState() != connectionActive { return false } switch state := ch.mutable.state; state { case ChannelClient, ChannelListening: break default: return false } ch.mutable.conns[c.connID] = c return true } func (ch *Channel) connectionActive(c *Connection, direction connectionDirection) { c.log.Debugf("New active %v connection for peer %v", direction, c.remotePeerInfo.HostPort) if added := ch.addConnection(c, direction); !added { // The channel isn't in a valid state to accept this connection, close the connection. c.close(LogField{"reason", "new active connection on closing channel"}) return } ch.addConnectionToPeer(c.remotePeerInfo.HostPort, c, direction) } func (ch *Channel) addConnectionToPeer(hostPort string, c *Connection, direction connectionDirection) { p := ch.RootPeers().GetOrAdd(hostPort) if err := p.addConnection(c, direction); err != nil { c.log.WithFields( LogField{"remoteHostPort", c.remotePeerInfo.HostPort}, LogField{"direction", direction}, ErrField(err), ).Warn("Failed to add connection to peer.") } ch.updatePeer(p) } func (ch *Channel) inboundConnectionActive(c *Connection) { ch.connectionActive(c, inbound) } func (ch *Channel) outboundConnectionActive(c *Connection) { ch.connectionActive(c, outbound) } // removeClosedConn removes a connection if it's closed. // Until a connection is fully closed, the channel must keep track of it. func (ch *Channel) removeClosedConn(c *Connection) { if c.readState() != connectionClosed { return } ch.mutable.Lock() delete(ch.mutable.conns, c.connID) ch.mutable.Unlock() } func (ch *Channel) getMinConnectionState() connectionState { minState := connectionClosed for _, c := range ch.mutable.conns { if s := c.readState(); s < minState { minState = s } } return minState } // connectionCloseStateChange is called when a connection's close state changes. 
func (ch *Channel) connectionCloseStateChange(c *Connection) { ch.removeClosedConn(c) if peer, ok := ch.RootPeers().Get(c.remotePeerInfo.HostPort); ok { peer.connectionCloseStateChange(c) ch.updatePeer(peer) } if c.outboundHP != "" && c.outboundHP != c.remotePeerInfo.HostPort { // Outbound connections may be in multiple peers. if peer, ok := ch.RootPeers().Get(c.outboundHP); ok { peer.connectionCloseStateChange(c) ch.updatePeer(peer) } } chState := ch.State() if chState != ChannelStartClose && chState != ChannelInboundClosed { return } ch.mutable.RLock() minState := ch.getMinConnectionState() ch.mutable.RUnlock() var updateTo ChannelState if minState >= connectionClosed { updateTo = ChannelClosed } else if minState >= connectionInboundClosed && chState == ChannelStartClose { updateTo = ChannelInboundClosed } var updatedToState ChannelState if updateTo > 0 { ch.mutable.Lock() // Recheck the state as it's possible another goroutine changed the state // from what we expected, and so we might make a stale change. if ch.mutable.state == chState { ch.mutable.state = updateTo updatedToState = updateTo } ch.mutable.Unlock() chState = updateTo } c.log.Debugf("ConnectionCloseStateChange channel state = %v connection minState = %v", chState, minState) if updatedToState == ChannelClosed { ch.onClosed() } } func (ch *Channel) onClosed() { removeClosedChannel(ch) close(ch.closed) ch.log.Infof("Channel closed.") } // Closed returns whether this channel has been closed with .Close() func (ch *Channel) Closed() bool { return ch.State() == ChannelClosed } // ClosedChan returns a channel that will close when the Channel has completely // closed. func (ch *Channel) ClosedChan() <-chan struct{} { return ch.closed } // State returns the current channel state. func (ch *Channel) State() ChannelState { ch.mutable.RLock() state := ch.mutable.state ch.mutable.RUnlock() return state } // Close starts a graceful Close for the channel. This does not happen immediately: // 1. 
This call closes the Listener and starts closing connections. // 2. When all incoming connections are drained, the connection blocks new outgoing calls. // 3. When all connections are drained, the channel's state is updated to Closed. func (ch *Channel) Close() { ch.Logger().Info("Channel.Close called.") var connections []*Connection var channelClosed bool func() { ch.mutable.Lock() defer ch.mutable.Unlock() if ch.mutable.state == ChannelClosed { ch.Logger().Info("Channel already closed, skipping additional Close() calls") return } if ch.mutable.l != nil { ch.mutable.l.Close() } // Stop the idle connections timer. ch.mutable.idleSweep.Stop() ch.mutable.state = ChannelStartClose if len(ch.mutable.conns) == 0 { ch.mutable.state = ChannelClosed channelClosed = true } for _, c := range ch.mutable.conns { connections = append(connections, c) } }() for _, c := range connections { c.close(LogField{"reason", "channel closing"}) } if channelClosed { ch.onClosed() } } // RelayHost returns the channel's RelayHost, if any. func (ch *Channel) RelayHost() RelayHost { return ch.relayHost } func (o *ChannelOptions) validateIdleCheck() error { if o.IdleCheckInterval > 0 && o.MaxIdleTime <= 0 { return errMaxIdleTimeNotSet } return nil } func toStringSet(ss []string) map[string]struct{} { set := make(map[string]struct{}, len(ss)) for _, s := range ss { set[s] = struct{}{} } return set } // take a list of service::method formatted string and make // the map[service::method]struct{} set func toServiceMethodSet(sms []string) (map[string]struct{}, error) { set := map[string]struct{}{} for _, sm := range sms { if len(strings.Split(sm, "::")) != 2 { return nil, fmt.Errorf("each %q value should be of service::Method format but got %q", "SkipHandlerMethods", sm) } set[sm] = struct{}{} } return set, nil } ================================================ FILE: channel_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"io/ioutil"
	"math"
	"os"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/mocktracer"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// toMap flattens log fields into a key/value map for easy lookups in tests.
func toMap(fields LogFields) map[string]interface{} {
	m := make(map[string]interface{}, len(fields))
	for _, f := range fields {
		m[f.Key] = f.Value
	}
	return m
}

// TestNewChannel checks the local peer info of a freshly created channel.
func TestNewChannel(t *testing.T) {
	ch, err := NewChannel("svc", &ChannelOptions{
		ProcessName: "pname",
	})
	require.NoError(t, err, "NewChannel failed")
	// Close the channel so it does not outlive the test.
	defer ch.Close()

	assert.Equal(t, LocalPeerInfo{
		ServiceName: "svc",
		PeerInfo: PeerInfo{
			ProcessName: "pname",
			HostPort:    ephemeralHostPort,
			IsEphemeral: true,
			Version: PeerVersion{
				Language:        "go",
				LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"),
				TChannelVersion: VersionInfo,
			},
		},
	}, ch.PeerInfo(), "Wrong local peer info")
}

// TestLoggers checks the fields attached to channel and subchannel loggers.
func TestLoggers(t *testing.T) {
	ch, err := NewChannel("svc", &ChannelOptions{
		Logger: NewLogger(ioutil.Discard),
	})
	require.NoError(t, err, "NewChannel failed")
	defer ch.Close()

	peerInfo := ch.PeerInfo()
	fields := toMap(ch.Logger().Fields())
	assert.Equal(t, peerInfo.ServiceName, fields["serviceName"])

	sc := ch.GetSubChannel("subch")
	fields = toMap(sc.Logger().Fields())
	assert.Equal(t, peerInfo.ServiceName, fields["serviceName"])
	assert.Equal(t, "subch", fields["subchannel"])
}

// TestStats checks the stats tags reported by a channel and its subchannel.
func TestStats(t *testing.T) {
	ch, err := NewChannel("svc", &ChannelOptions{
		Logger: NewLogger(ioutil.Discard),
	})
	require.NoError(t, err, "NewChannel failed")
	defer ch.Close()

	hostname, err := os.Hostname()
	require.NoError(t, err, "Hostname failed")

	peerInfo := ch.PeerInfo()
	tags := ch.StatsTags()
	assert.NotNil(t, ch.StatsReporter(), "StatsReporter missing")
	assert.Equal(t, peerInfo.ProcessName, tags["app"], "app tag")
	assert.Equal(t, peerInfo.ServiceName, tags["service"], "service tag")
	assert.Equal(t, hostname, tags["host"], "hostname tag")

	sc := ch.GetSubChannel("subch")
	subTags := sc.StatsTags()
	assert.NotNil(t, sc.StatsReporter(), "StatsReporter missing")
	for k, v := range tags {
		assert.Equal(t, v, subTags[k], "subchannel missing tag %v", k)
	}
	assert.Equal(t, "subch", subTags["subchannel"], "subchannel tag missing")
}

// TestRelayMaxTTL checks how RelayMaxTimeout values are validated by
// NewChannel: out-of-range values fall back to the default.
func TestRelayMaxTTL(t *testing.T) {
	tests := []struct {
		max      time.Duration
		expected time.Duration
	}{
		{time.Second, time.Second},
		{-time.Second, _defaultRelayMaxTimeout},
		{0, _defaultRelayMaxTimeout},
		{time.Microsecond, _defaultRelayMaxTimeout},
		{math.MaxUint32 * time.Millisecond, math.MaxUint32 * time.Millisecond},
		{(math.MaxUint32 + 1) * time.Millisecond, _defaultRelayMaxTimeout},
	}

	for _, tt := range tests {
		ch, err := NewChannel("svc", &ChannelOptions{
			RelayMaxTimeout: tt.max,
		})
		// require (not assert) so a failed creation cannot lead to a nil
		// dereference on the next line.
		require.NoError(t, err, "Unexpected error when creating channel.")
		assert.Equal(t, tt.expected, ch.relayMaxTimeout, "Unexpected max timeout on channel.")
		// Close inside the loop; a defer would keep every channel open until
		// the test returns.
		ch.Close()
	}
}

// TestIsolatedSubChannelsDontSharePeers checks that regular subchannels share
// the channel's peer list while isolated subchannels get their own.
func TestIsolatedSubChannelsDontSharePeers(t *testing.T) {
	ch, err := NewChannel("svc", &ChannelOptions{
		Logger: NewLogger(ioutil.Discard),
	})
	require.NoError(t, err, "NewChannel failed")
	defer ch.Close()

	sub := ch.GetSubChannel("svc-ringpop")
	if ch.peers != sub.peers {
		t.Error("Channel and subchannel don't share the same peer list.")
	}

	isolatedSub := ch.GetSubChannel("svc-shy-ringpop", Isolated)
	if ch.peers == isolatedSub.peers {
		t.Error("Channel and isolated subchannel share the same peer list.")
	}

	// Nobody knows about the peer.
	assert.Nil(t, ch.peers.peersByHostPort["127.0.0.1:3000"])
	assert.Nil(t, sub.peers.peersByHostPort["127.0.0.1:3000"])
	assert.Nil(t, isolatedSub.peers.peersByHostPort["127.0.0.1:3000"])

	// Uses of the parent channel should be reflected in the subchannel, but
	// not the isolated subchannel.
	ch.Peers().Add("127.0.0.1:3000")
	assert.NotNil(t, ch.peers.peersByHostPort["127.0.0.1:3000"])
	assert.NotNil(t, sub.peers.peersByHostPort["127.0.0.1:3000"])
	assert.Nil(t, isolatedSub.peers.peersByHostPort["127.0.0.1:3000"])
}

// TestChannelTracerMethod checks that Tracer returns the configured tracer,
// and otherwise looks up the global tracer on every call.
func TestChannelTracerMethod(t *testing.T) {
	mockTracer := mocktracer.New()
	ch, err := NewChannel("svc", &ChannelOptions{
		Tracer: mockTracer,
	})
	require.NoError(t, err)
	defer ch.Close()
	assert.Equal(t, mockTracer, ch.Tracer(), "expecting tracer passed at initialization")

	ch, err = NewChannel("svc", &ChannelOptions{})
	require.NoError(t, err)
	defer ch.Close()
	assert.EqualValues(t, opentracing.GlobalTracer(), ch.Tracer(), "expecting default tracer")

	// because ch.Tracer() function is doing dynamic lookup, we can change global tracer
	origTracer := opentracing.GlobalTracer()
	defer opentracing.InitGlobalTracer(origTracer)

	opentracing.InitGlobalTracer(mockTracer)
	assert.Equal(t, mockTracer, ch.Tracer(), "expecting tracer set as global tracer")
}

// TestToServiceMethodSet checks conversion of service::Method strings into a
// set, including rejection of malformed entries.
func TestToServiceMethodSet(t *testing.T) {
	tests := []struct {
		desc    string
		sms     []string
		want    map[string]struct{}
		wantErr string
	}{
		{
			desc: "single service, single method",
			sms:  []string{"service::Method"},
			want: map[string]struct{}{
				"service::Method": {},
			},
		},
		{
			desc: "single service, multiple methods",
			sms:  []string{"service::Method1", "service::Method2", "service::Method3"},
			want: map[string]struct{}{
				"service::Method1": {},
				"service::Method2": {},
				"service::Method3": {},
			},
		},
		{
			desc:    "invalid input",
			sms:     []string{"notDelimitedByDoubleColons"},
			wantErr: `each "SkipHandlerMethods" value should be of service::Method format but got "notDelimitedByDoubleColons"`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			r, err := toServiceMethodSet(tt.sms)
			if tt.wantErr != "" {
				assert.EqualError(t, err, tt.wantErr)
				return
			}
			// The success path must not report an error.
			require.NoError(t, err)
			assert.Equal(t, tt.want, r)
		})
	}
}

================================================
FILE: channel_utils_test.go
================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "testing" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" ) // NewServer creates a new server and returns the channel, service name, and host port. 
func NewServer(t testing.TB, opts *testutils.ChannelOpts) (*Channel, string, string) { ch := testutils.NewServer(t, opts) peerInfo := ch.PeerInfo() return ch, peerInfo.ServiceName, peerInfo.HostPort } ================================================ FILE: channelstate_string.go ================================================ // generated by stringer -type=ChannelState; DO NOT EDIT package tchannel import "fmt" const _ChannelState_name = "ChannelClientChannelListeningChannelStartCloseChannelInboundClosedChannelClosed" var _ChannelState_index = [...]uint8{0, 13, 29, 46, 66, 79} func (i ChannelState) String() string { i -= 1 if i < 0 || i+1 >= ChannelState(len(_ChannelState_index)) { return fmt.Sprintf("ChannelState(%d)", i+1) } return _ChannelState_name[_ChannelState_index[i]:_ChannelState_index[i+1]] } ================================================ FILE: checked_frame_pool.go ================================================ package tchannel import ( "fmt" "runtime" "sync" ) // CheckedFramePoolForTest tracks gets and releases of frames, verifies that // frames aren't double released, and can be used to check for frame leaks. // As such, it is not performant, nor is it even a proper frame pool. // // It is intended to be used ONLY in tests. type CheckedFramePoolForTest struct { mu sync.Mutex allocations map[*Frame]string badRelease []string } // NewCheckedFramePoolForTest initializes a new CheckedFramePoolForTest. func NewCheckedFramePoolForTest() *CheckedFramePoolForTest { return &CheckedFramePoolForTest{ allocations: make(map[*Frame]string), } } // Get implements FramePool func (p *CheckedFramePoolForTest) Get() *Frame { p.mu.Lock() defer p.mu.Unlock() frame := NewFrame(MaxFramePayloadSize) p.allocations[frame] = recordStack() return frame } // Release implements FramePool func (p *CheckedFramePoolForTest) Release(f *Frame) { // Make sure the payload is not used after this point by clearing the frame. 
zeroOut(f.Payload) f.Payload = nil zeroOut(f.buffer) f.buffer = nil zeroOut(f.headerBuffer) f.headerBuffer = nil f.Header = FrameHeader{} p.mu.Lock() defer p.mu.Unlock() if _, ok := p.allocations[f]; !ok { p.badRelease = append(p.badRelease, "bad Release at "+recordStack()) return } delete(p.allocations, f) } // CheckedFramePoolForTestResult contains info on mismatched gets/releases type CheckedFramePoolForTestResult struct { BadReleases []string Unreleased []string } // HasIssues indicates whether there were any issues with gets/releases func (r CheckedFramePoolForTestResult) HasIssues() bool { return len(r.BadReleases)+len(r.Unreleased) > 0 } // CheckEmpty returns the number of unreleased frames in the pool func (p *CheckedFramePoolForTest) CheckEmpty() CheckedFramePoolForTestResult { p.mu.Lock() defer p.mu.Unlock() var badCalls []string for f, s := range p.allocations { badCalls = append(badCalls, fmt.Sprintf("frame %p: %v not released, get from: %v", f, f.Header, s)) } return CheckedFramePoolForTestResult{ Unreleased: badCalls, BadReleases: p.badRelease, } } func recordStack() string { buf := make([]byte, 4096) runtime.Stack(buf, false /* all */) return string(buf) } func zeroOut(bs []byte) { for i := range bs { bs[i] = 0 } } ================================================ FILE: checked_frame_pool_test.go ================================================ package tchannel import ( "strings" "testing" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go/testutils/goroutines" ) func TestCheckedFramePoolForTest(t *testing.T) { tests := []struct { msg string operations func(pool *CheckedFramePoolForTest) wantHasIssues bool wantBadAllocations int wantBadReleases int }{ { msg: "no bad releases or leaks", operations: func(pool *CheckedFramePoolForTest) { for i := 0; i < 10; i++ { pool.Release(pool.Get()) } }, }, { msg: "frames are leaked", operations: func(pool *CheckedFramePoolForTest) { for i := 0; i < 10; i++ { pool.Release(pool.Get()) } for i := 0; i < 
10; i++ { _ = pool.Get() } }, wantHasIssues: true, wantBadAllocations: 10, }, { msg: "frames are double released", operations: func(pool *CheckedFramePoolForTest) { for i := 0; i < 10; i++ { pool.Release(pool.Get()) } f := pool.Get() pool.Release(f) pool.Release(f) }, wantHasIssues: true, wantBadReleases: 1, }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { pool := NewCheckedFramePoolForTest() tt.operations(pool) results := pool.CheckEmpty() assert.Equal(t, tt.wantHasIssues, results.HasIssues(), "Unexpected HasIssues() state") assert.Equal(t, tt.wantBadAllocations, len(results.Unreleased), "Unexpected allocs") assert.Equal(t, tt.wantBadReleases, len(results.BadReleases), "Unexpected bad releases") }) } } func CheckFramePoolIsEmpty(t testing.TB, pool *CheckedFramePoolForTest) { t.Helper() stacks := goroutines.GetAll() if result := pool.CheckEmpty(); result.HasIssues() { if len(result.Unreleased) > 0 { t.Errorf("Frame pool has %v unreleased frames, errors:\n%v\nStacks:%v", len(result.Unreleased), strings.Join(result.Unreleased, "\n"), stacks) } if len(result.BadReleases) > 0 { t.Errorf("Frame pool has %v bad releases, errors:\n%v\nStacks:%v", len(result.BadReleases), strings.Join(result.BadReleases, "\n"), stacks) } } } ================================================ FILE: checksum.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"hash"
	"hash/crc32"
	"sync"
)

// checksumPools holds one pool of reusable Checksum values per ChecksumType,
// indexed by the numeric value of the type.
var checksumPools [checksumCount]sync.Pool

// A ChecksumType is a checksum algorithm supported by TChannel for checksumming call bodies
type ChecksumType byte

const (
	// ChecksumTypeNone indicates no checksum is included in the message
	ChecksumTypeNone ChecksumType = 0

	// ChecksumTypeCrc32 indicates the message checksum is calculated using crc32
	ChecksumTypeCrc32 ChecksumType = 1

	// ChecksumTypeFarmhash indicates the message checksum is calculated using Farmhash
	ChecksumTypeFarmhash ChecksumType = 2

	// ChecksumTypeCrc32C indicates the message checksum is calculated using crc32c
	ChecksumTypeCrc32C ChecksumType = 3

	// checksumCount is the number of ChecksumType values and sizes checksumPools.
	checksumCount = 4
)

// init installs the constructor used by each per-type pool when it is empty.
func init() {
	castagnoli := crc32.MakeTable(crc32.Castagnoli)
	newNull := func() interface{} { return nullChecksum{} }

	checksumPools[ChecksumTypeNone].New = newNull
	checksumPools[ChecksumTypeCrc32].New = func() interface{} {
		return newHashChecksum(ChecksumTypeCrc32, crc32.NewIEEE())
	}
	checksumPools[ChecksumTypeCrc32C].New = func() interface{} {
		return newHashChecksum(ChecksumTypeCrc32C, crc32.New(castagnoli))
	}

	// TODO: Implement farm hash.
	// Until then the Farmhash pool hands out the no-op checksum.
	checksumPools[ChecksumTypeFarmhash].New = newNull
}

// ChecksumSize returns the size in bytes of the checksum calculation
func (t ChecksumType) ChecksumSize() int {
	switch t {
	case ChecksumTypeCrc32, ChecksumTypeCrc32C:
		return crc32.Size
	case ChecksumTypeFarmhash:
		return 4
	}
	// ChecksumTypeNone and any unrecognized type report a size of zero.
	return 0
}

// pool returns the sync.Pool used to pool checksums for this type.
func (t ChecksumType) pool() *sync.Pool {
	idx := int(t)
	return &checksumPools[idx]
}

// New creates a new Checksum of the given type. The value comes from the
// type's pool and is reset before being returned.
func (t ChecksumType) New() Checksum {
	cs := t.pool().Get().(Checksum)
	cs.Reset()
	return cs
}

// Release puts a Checksum back in the pool.
func (t ChecksumType) Release(checksum Checksum) {
	p := t.pool()
	p.Put(checksum)
}

// A Checksum calculates a running checksum against a bytestream
type Checksum interface {
	// TypeCode returns the type of this checksum
	TypeCode() ChecksumType

	// Size returns the size of the calculated checksum
	Size() int

	// Add adds bytes to the checksum calculation
	Add(b []byte) []byte

	// Sum returns the current checksum value
	Sum() []byte

	// Release puts a Checksum back in the pool.
	Release()

	// Reset resets the checksum state to the default 0 value.
	Reset()
}

// nullChecksum is the Checksum used when no checksum is calculated: every
// method is a no-op and the reported size is zero.
type nullChecksum struct{}

// TypeCode returns the type of the checksum, which is always ChecksumTypeNone.
func (c nullChecksum) TypeCode() ChecksumType { return ChecksumTypeNone }

// Size returns the size of the checksum data; for the null checksum this is zero.
func (c nullChecksum) Size() int { return 0 }

// Add ignores the given bytes and returns nil.
func (c nullChecksum) Add(b []byte) []byte { return nil }

// Sum returns nil, as there is no checksum value.
func (c nullChecksum) Sum() []byte { return nil }

// Release puts the checksum back in the pool for its TypeCode.
func (c nullChecksum) Release() {
	c.TypeCode().Release(c)
}

// Reset does nothing, as the null checksum holds no state.
func (c nullChecksum) Reset() {}

// hashChecksum is a Checksum backed by a hash.Hash.
type hashChecksum struct {
	checksumType ChecksumType
	hash         hash.Hash
	// sumCache is a zero-length buffer that Sum appends the hash value to.
	sumCache []byte
}

// newHashChecksum wraps the given hash as a Checksum reporting type t.
func newHashChecksum(t ChecksumType, hasher hash.Hash) *hashChecksum {
	cs := &hashChecksum{checksumType: t, hash: hasher}
	cs.sumCache = make([]byte, 0, 4)
	return cs
}

// TypeCode returns the type of the checksum
func (h *hashChecksum) TypeCode() ChecksumType { return h.checksumType }

// Size returns the size of the checksum data
func (h *hashChecksum) Size() int { return h.hash.Size() }

// Add feeds b into the running hash and returns the updated checksum value.
func (h *hashChecksum) Add(b []byte) []byte {
	h.hash.Write(b)
	return h.Sum()
}

// Sum returns the current value of the checksum calculation, appended to the
// (empty) sumCache buffer.
func (h *hashChecksum) Sum() []byte {
	return h.hash.Sum(h.sumCache)
}

// Release puts the checksum back in the pool for its TypeCode.
func (h *hashChecksum) Release() {
	h.TypeCode().Release(h)
}

// Reset resets the checksum state to the default 0 value.
func (h *hashChecksum) Reset() {
	h.hash.Reset()
}

// noReleaseChecksum overrides .Release() with a NOOP so that the checksum won't
// be released by the fragmentingWriter when it is managed externally, e.g. by the
// relayer
type noReleaseChecksum struct {
	Checksum
}

// Release intentionally does nothing; see the type comment.
func (n *noReleaseChecksum) Release() {}

================================================
FILE: close_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "math/rand" "sync" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/testutils/goroutines" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "golang.org/x/net/context" ) type channelState struct { testServer *testutils.TestServer closeCh chan struct{} closed bool } func makeCall(client *Channel, server *testutils.TestServer) error { ctx, cancel := NewContext(time.Second) defer cancel() _, _, _, err := raw.Call(ctx, client, server.HostPort(), server.ServiceName(), "test", nil, nil) return err } func assertStateChangesTo(t testing.TB, ch *Channel, state ChannelState) { var lastState ChannelState require.True(t, testutils.WaitFor(time.Second, func() bool { lastState = ch.State() return lastState == state }), "Channel state is %v expected %v", lastState, state) } func TestCloseOnlyListening(t *testing.T) { ch := testutils.NewServer(t, nil) // If there are no connections, then the channel should close immediately. ch.Close() assert.Equal(t, ChannelClosed, ch.State()) assert.True(t, ch.Closed(), "Channel should be closed") } func TestCloseNewClient(t *testing.T) { ch := testutils.NewClient(t, nil) // If there are no connections, then the channel should close immediately. ch.Close() assert.Equal(t, ChannelClosed, ch.State()) assert.True(t, ch.Closed(), "Channel should be closed") } func ignoreError(h *testHandler) Handler { return raw.Wrap(onErrorTestHandler{ testHandler: h, onError: func(_ context.Context, err error) {}, }) } func TestCloseAfterTimeout(t *testing.T) { // Disable log verfication since connections are closed after a timeout // and the relay might still be reading/writing to the connection. // TODO: Ideally, we only disable log verification on the relay. 
opts := testutils.NewOpts().DisableLogVerification() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { testHandler := newTestHandler(t) ts.Register(ignoreError(testHandler), "block") ctx, cancel := NewContext(100 * time.Millisecond) defer cancel() // Make a call, wait for it to timeout. clientCh := ts.NewClient(nil) _, _, _, err := raw.Call(ctx, clientCh, ts.HostPort(), ts.ServiceName(), "block", nil, nil) require.Equal(t, ErrTimeout, err, "Expected call to timeout") // The client channel should also close immediately. clientCh.Close() assertStateChangesTo(t, clientCh, ChannelClosed) assert.True(t, clientCh.Closed(), "Channel should be closed") }) } func TestRelayCloseTimeout(t *testing.T) { opts := testutils.NewOpts(). SetRelayOnly(). // this is a relay-specific test. DisableLogVerification() // we're causing errors on purpose. opts.DefaultConnectionOptions.MaxCloseTime = 100 * time.Millisecond testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { gotCall := make(chan struct{}) unblock := make(chan struct{}) defer close(unblock) testutils.RegisterEcho(ts.Server(), func() { close(gotCall) <-unblock }) clientCh := ts.NewClient(opts) // Start a call in the background, since it will block go func() { ctx, cancel := NewContext(10 * time.Second) defer cancel() _, _, _, err := raw.Call(ctx, clientCh, ts.HostPort(), ts.ServiceName(), "echo", nil, nil) require.Error(t, err) assert.Equal(t, ErrCodeNetwork, GetSystemErrorCode(err), "expect network error from relay closing connection on timeout") }() <-gotCall ts.Relay().Close() // The relay should close within the timeout. 
<-ts.Relay().ClosedChan() }) } func TestRaceExchangesWithClose(t *testing.T) { var wg sync.WaitGroup opts := testutils.NewOpts().DisableLogVerification() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { var ( server = ts.Server() gotCall = make(chan struct{}) completeCall = make(chan struct{}) ) testutils.RegisterFunc(server, "dummy", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{}, nil }) testutils.RegisterEcho(server, func() { close(gotCall) <-completeCall }) client := ts.NewClient(opts) defer client.Close() callDone := make(chan struct{}) go func() { // n.b. Use a longer context here; server shutdown is inherently // nondeterministic, and now that it's blocking on channel and // connection closure, it can take anywhere from 0-2s to fully // close all of its internals. ctx, cancel := context.WithTimeout( context.Background(), testutils.Timeout(5*time.Second), ) defer cancel() assert.NoError( t, testutils.CallEchoWithContext( ctx, client, ts.HostPort(), server.ServiceName(), &raw.Args{}, ), "Echo failed", ) close(callDone) }() // Wait until the server recieves a call, so it has an active inbound. <-gotCall // Start a bunch of clients to trigger races between connecting and close. var closed atomic.Bool for i := 0; i < 100; i++ { wg.Add(1) go func() { defer wg.Done() // We don't use ts.NewClient here to avoid data races. c := testutils.NewClient(t, opts) defer c.Close() if closed.Load() { return } ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() if err := c.Ping(ctx, ts.HostPort()); err != nil { return } if closed.Load() { return } raw.Call(ctx, c, ts.HostPort(), server.ServiceName(), "dummy", nil, nil) }() } // Now try to close the channel, it should block since there's active exchanges. server.Close() closed.Store(true) // n.b. As it's shutting down, server state can be in any of the // outlined states below. 
It doesn't matter which specific state // it's in, as long as we're verifying that it's at least in the // process of shutting down. var ( timeout = time.After(testutils.Timeout(time.Second)) validState = func() bool { switch ts.Server().State() { case ChannelStartClose, ChannelInboundClosed, ChannelClosed: return true default: return false } } ) ticker := time.NewTicker(25 * time.Millisecond) defer ticker.Stop() for !validState() { select { case <-ticker.C: case <-timeout: require.FailNow( t, "server state did not transition as expected: %v", ts.Server().State(), ) } } closed.Store(true) close(completeCall) <-callDone }) // Wait for all calls to complete wg.Wait() } // TestCloseStress ensures that once a Channel is closed, it cannot be reached. func TestCloseStress(t *testing.T) { CheckStress(t) const numHandlers = 5 handler := &swapper{t} var lock sync.RWMutex var wg sync.WaitGroup var channels []*channelState // Start numHandlers servers, and don't close the connections till they are signalled. for i := 0; i < numHandlers; i++ { wg.Add(1) go func() { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(raw.Wrap(handler), "test") chState := &channelState{ testServer: ts, closeCh: make(chan struct{}), } lock.Lock() channels = append(channels, chState) lock.Unlock() wg.Done() // Wait for a close signal. <-chState.closeCh // Lock until the connection is closed. lock.Lock() chState.closed = true }) }() } // Wait till all the channels have been registered. wg.Wait() // Start goroutines to make calls until the test has ended. testEnded := make(chan struct{}) for i := 0; i < 10; i++ { go func() { for { select { case <-testEnded: return default: // Keep making requests till the test ends. } // Get 2 random channels and make a call from one to the other. 
lock.RLock() chState1 := channels[rand.Intn(len(channels))] chState2 := channels[rand.Intn(len(channels))] if chState1 == chState2 { lock.RUnlock() continue } // Grab a read lock to make sure channels aren't closed while we call. ch1Closed := chState1.closed ch2Closed := chState2.closed err := makeCall(chState1.testServer.NewClient(nil), chState2.testServer) lock.RUnlock() if ch1Closed || ch2Closed { assert.Error( t, err, "Call from %v (%v) to %v (%v) should fail", chState1.testServer.ServiceName(), chState1.testServer.HostPort(), chState2.testServer.ServiceName(), chState2.testServer.HostPort(), ) } else { assert.NoError( t, err, "Call from %v (%v) to %v (%v) should not fail", chState1.testServer.ServiceName(), chState1.testServer.HostPort(), chState2.testServer.ServiceName(), chState2.testServer.HostPort(), ) } } }() } // Kill connections till all of the connections are dead. for i := 0; i < numHandlers; i++ { time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) channels[i].closeCh <- struct{}{} } } type closeSemanticsTest struct { *testing.T isolated bool } func (t *closeSemanticsTest) makeServer(name string) (*Channel, chan struct{}) { ch := testutils.NewServer(t.T, &testutils.ChannelOpts{ServiceName: name}) c := make(chan struct{}) testutils.RegisterFunc(ch, "stream", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { <-c return &raw.Res{}, nil }) testutils.RegisterFunc(ch, "call", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{}, nil }) return ch, c } func (t *closeSemanticsTest) withNewClient(f func(ch *Channel)) { ch := testutils.NewClient(t.T, &testutils.ChannelOpts{ServiceName: "client"}) f(ch) ch.Close() } func (t *closeSemanticsTest) startCall(from *Channel, to *Channel, method string) (*OutboundCall, error) { ctx, _ := NewContext(time.Second) var call *OutboundCall var err error toPeer := to.PeerInfo() if t.isolated { sc := from.GetSubChannel(toPeer.ServiceName, Isolated) sc.Peers().Add(toPeer.HostPort) 
call, err = sc.BeginCall(ctx, method, nil) } else { call, err = from.BeginCall(ctx, toPeer.HostPort, toPeer.ServiceName, method, nil) } return call, err } func (t *closeSemanticsTest) call(from *Channel, to *Channel) error { call, err := t.startCall(from, to, "call") if err == nil { _, _, _, err = raw.WriteArgs(call, nil, nil) } return err } func (t *closeSemanticsTest) callStream(from *Channel, to *Channel) <-chan struct{} { c := make(chan struct{}) call, err := t.startCall(from, to, "stream") require.NoError(t, err, "stream call failed to start") require.NoError(t, NewArgWriter(call.Arg2Writer()).Write(nil), "write arg2") require.NoError(t, NewArgWriter(call.Arg3Writer()).Write(nil), "write arg3") go func() { var d []byte assert.NoError(t, NewArgReader(call.Response().Arg2Reader()).Read(&d), "read arg2 from %v to %v", from.PeerInfo(), to.PeerInfo()) assert.NoError(t, NewArgReader(call.Response().Arg3Reader()).Read(&d), "read arg3") c <- struct{}{} }() return c } func (t *closeSemanticsTest) runTest() { s1, s1C := t.makeServer("s1") s2, s2C := t.makeServer("s2") // Make a call from s1 -> s2, and s2 -> s1 call1 := t.callStream(s1, s2) call2 := t.callStream(s2, s1) // s1 and s2 are both open, so calls to it should be successful. t.withNewClient(func(ch *Channel) { require.NoError(t, t.call(ch, s1), "failed to call s1") require.NoError(t, t.call(ch, s2), "failed to call s2") }) require.NoError(t, t.call(s1, s2), "call s1 -> s2 failed") require.NoError(t, t.call(s2, s1), "call s2 -> s1 failed") // Close s1, should no longer be able to call it. s1.Close() assert.Equal(t, ChannelStartClose, s1.State()) t.withNewClient(func(ch *Channel) { assert.Error(t, t.call(ch, s1), "closed channel should not accept incoming calls") require.NoError(t, t.call(ch, s2), "closed channel with pending incoming calls should allow outgoing calls") }) // Even an existing connection (e.g. from s2) should fail. // TODO: this will fail until the peer is shared. 
if !assert.Equal(t, ErrChannelClosed, t.call(s2, s1), "closed channel should not accept incoming calls") { t.Errorf("err %v", t.call(s2, s1)) } require.Error(t, t.call(s1, s2), "closed channel with pending incoming calls disallows outgoing calls") // Once the incoming connection is drained, outgoing calls should fail. s1C <- struct{}{} <-call2 assertStateChangesTo(t.T, s1, ChannelInboundClosed) require.Error(t, t.call(s1, s2), "closed channel with no pending incoming calls should not allow outgoing calls") // Now the channel should be completely closed as there are no pending connections. s2C <- struct{}{} <-call1 assertStateChangesTo(t.T, s1, ChannelClosed) // Close s2 so we don't leave any goroutines running. s2.Close() } func TestCloseSemantics(t *testing.T) { // We defer the check as we want it to run after the SetTimeout clears the timeout. defer goroutines.VerifyNoLeaks(t, nil) defer testutils.SetTimeout(t, 2*time.Second)() ct := &closeSemanticsTest{t, false /* isolated */} ct.runTest() } func TestCloseSemanticsIsolated(t *testing.T) { // We defer the check as we want it to run after the SetTimeout clears the timeout. 
defer goroutines.VerifyNoLeaks(t, nil) defer testutils.SetTimeout(t, 2*time.Second)() ct := &closeSemanticsTest{t, true /* isolated */} ct.runTest() } func TestCloseSingleChannel(t *testing.T) { ch := testutils.NewServer(t, nil) var connected sync.WaitGroup var completed sync.WaitGroup blockCall := make(chan struct{}) testutils.RegisterFunc(ch, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { connected.Done() <-blockCall return &raw.Res{ Arg2: args.Arg2, Arg3: args.Arg3, }, nil }) for i := 0; i < 10; i++ { connected.Add(1) completed.Add(1) go func() { ctx, cancel := NewContext(time.Second) defer cancel() peerInfo := ch.PeerInfo() _, _, _, err := raw.Call(ctx, ch, peerInfo.HostPort, peerInfo.ServiceName, "echo", nil, nil) assert.NoError(t, err, "Call failed") completed.Done() }() } // Wait for all calls to connect before triggerring the Close (so they do not fail). connected.Wait() ch.Close() // Unblock the calls, and wait for all the calls to complete. close(blockCall) completed.Wait() // Once all calls are complete, the channel should be closed. assertStateChangesTo(t, ch, ChannelClosed) goroutines.VerifyNoLeaks(t, nil) } func TestCloseOneSide(t *testing.T) { ch1 := testutils.NewServer(t, &testutils.ChannelOpts{ServiceName: "client"}) ch2 := testutils.NewServer(t, &testutils.ChannelOpts{ServiceName: "server"}) connected := make(chan struct{}) completed := make(chan struct{}) blockCall := make(chan struct{}) testutils.RegisterFunc(ch2, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { connected <- struct{}{} <-blockCall return &raw.Res{ Arg2: args.Arg2, Arg3: args.Arg3, }, nil }) go func() { ctx, cancel := NewContext(time.Second) defer cancel() ch2Peer := ch2.PeerInfo() _, _, _, err := raw.Call(ctx, ch1, ch2Peer.HostPort, ch2Peer.ServiceName, "echo", nil, nil) assert.NoError(t, err, "Call failed") completed <- struct{}{} }() // Wait for connected before calling Close. 
<-connected ch1.Close() // Now unblock the call and wait for the call to complete. close(blockCall) <-completed // Once the call completes, the channel should be closed. assertStateChangesTo(t, ch1, ChannelClosed) // We need to close all open TChannels before verifying blocked goroutines. ch2.Close() goroutines.VerifyNoLeaks(t, nil) } // TestCloseSendError tests that system errors are not attempted to be sent when // a connection is closed, and ensures there's no race conditions such as the error // frame being added to the channel just as it is closed. func TestCloseSendError(t *testing.T) { var ( closed atomic.Uint32 counter atomic.Uint32 ) opts := testutils.NewOpts().DisableLogVerification() serverCh := testutils.NewServer(t, opts) testutils.RegisterEcho(serverCh, func() { if counter.Inc() > 10 { // Close the server in a goroutine to possibly trigger more race conditions. go func() { closed.Inc() serverCh.Close() }() } }) clientCh := testutils.NewClient(t, opts) // Create a connection that will be shared. require.NoError(t, testutils.Ping(clientCh, serverCh), "Ping from client to server failed") var wg sync.WaitGroup for i := 0; i < 100; i++ { wg.Add(1) go func() { time.Sleep(time.Duration(rand.Intn(1000)) * time.Microsecond) err := testutils.CallEcho(clientCh, serverCh.PeerInfo().HostPort, serverCh.ServiceName(), nil) if err != nil && closed.Load() == 0 { t.Errorf("Call failed: %v", err) } wg.Done() }() } // Wait for all the goroutines to end wg.Wait() clientCh.Close() goroutines.VerifyNoLeaks(t, nil) } ================================================ FILE: codecov.yml ================================================ coverage: range: 75..100 round: down precision: 2 status: project: default: enabled: yes target: 85% if_not_found: success if_ci_failed: error ignore: - "*_string.go" ================================================ FILE: conn_leak_test.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "io/ioutil" "runtime" "testing" "time" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/require" ) // This is a regression test for https://github.com/uber/tchannel-go/issues/643 // We want to ensure that once a connection is closed, there are no references // to the closed connection, and the GC frees the connection. // We use `runtime.SetFinalizer` to detect whether the GC has freed the object. // However, finalizers cannot be set on objects with circular references, // so we cannot set a finalizer on the connection, but instead set a finalizer // on a field of the connection which has the same lifetime. The connection // logger is unique per connection and does not have circular references // so we can use the logger, but need a pointer for `runtime.SetFinalizer`. 
// loggerPtr is a Logger implementation that uses a pointer unlike other
// TChannel loggers.
type loggerPtr struct {
	Logger
}

// WithFields wraps the embedded logger's WithFields result in a new
// *loggerPtr, so derived loggers are pointers as well.
func (l *loggerPtr) WithFields(fields ...LogField) Logger {
	return &loggerPtr{l.Logger.WithFields(fields...)}
}

// TestPeerConnectionLeaks verifies that, after the remote server is closed,
// the logger of the connection to it gets finalized by the garbage collector.
func TestPeerConnectionLeaks(t *testing.T) {
	// Disable log verification since we want to set our own logger.
	opts := testutils.NewOpts().NoRelay().DisableLogVerification()
	opts.Logger = &loggerPtr{NullLogger}

	// connFinalized is closed by the finalizer registered below.
	connFinalized := make(chan struct{})
	setFinalizer := func(p *Peer) {
		ctx, cancel := NewContext(time.Second)
		defer cancel()

		conn, err := p.GetConnection(ctx)
		require.NoError(t, err, "Failed to get connection")

		// The finalizer is attached to the connection's logger rather than to
		// the connection itself.
		runtime.SetFinalizer(conn.Logger(), func(interface{}) {
			close(connFinalized)
		})
	}

	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		s2Opts := testutils.NewOpts().SetServiceName("s2")
		s2Opts.Logger = NewLogger(ioutil.Discard)
		s2 := ts.NewServer(s2Opts)

		// Set a finalizer to detect when the connection from s1 -> s2 is freed.
		peer := ts.Server().Peers().GetOrAdd(s2.PeerInfo().HostPort)
		setFinalizer(peer)

		// Close s2, so that the connection in s1 to s2 is released.
		s2.Close()
		closed := testutils.WaitFor(3*time.Second, s2.Closed)
		require.True(t, closed, "s2 didn't close")

		// Trigger the GC which will call the finalizer, and ensure
		// that the connection logger was finalized.
		finalized := testutils.WaitFor(3*time.Second, func() bool {
			runtime.GC()
			select {
			case <-connFinalized:
				return true
			default:
				return false
			}
		})
		require.True(t, finalized, "Connection was not freed")
	})
}

================================================
FILE: connection.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/uber/tchannel-go/tos"

	"go.uber.org/atomic"
	"golang.org/x/net/context"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

const (
	// CurrentProtocolVersion is the current version of the TChannel protocol
	// supported by this stack
	CurrentProtocolVersion = 0x02

	// DefaultConnectTimeout is the default timeout used by net.Dial, if no timeout
	// is specified in the context.
	DefaultConnectTimeout = 5 * time.Second

	// DefaultConnectionBufferSize is the default size for the connection's read
	// and write channels.
	DefaultConnectionBufferSize = 512
)

// PeerVersion contains version related information for a specific peer.
// These values are extracted from the init headers.
type PeerVersion struct {
	Language        string `json:"language"`
	LanguageVersion string `json:"languageVersion"`
	TChannelVersion string `json:"tchannelVersion"`
}

// PeerInfo describes a TChannel peer.
type PeerInfo struct {
	// HostPort is the address that can be used to contact the peer, in the
	// form produced by net.JoinHostPort.
	HostPort string `json:"hostPort"`

	// ProcessName is the logical process name of the peer. It is only used
	// for logging and debugging.
	ProcessName string `json:"processName"`

	// IsEphemeral reports whether the remote host:port is ephemeral
	// (e.g. the peer is not listening on it).
	IsEphemeral bool `json:"isEphemeral"`

	// Version holds the version information reported by the remote peer.
	Version PeerVersion `json:"version"`
}

// String renders the peer as "hostPort(processName)".
func (p PeerInfo) String() string {
	return p.HostPort + "(" + p.ProcessName + ")"
}

// IsEphemeralHostPort returns whether the connection is from an ephemeral host:port.
func (p PeerInfo) IsEphemeralHostPort() bool {
	ephemeral := p.IsEphemeral
	return ephemeral
}

// LocalPeerInfo is a PeerInfo extended with the service name, which is only
// needed for the local peer.
type LocalPeerInfo struct {
	PeerInfo

	// ServiceName is the service name for the local peer.
	ServiceName string `json:"serviceName"`
}

// String renders the local peer as "serviceName: hostPort(processName)".
func (p LocalPeerInfo) String() string {
	return fmt.Sprintf("%s: %s", p.ServiceName, p.PeerInfo.String())
}

var (
	// ErrConnectionClosed is returned when a caller performs a method
	// on a closed connection.
	ErrConnectionClosed = errors.New("connection is closed")

	// ErrSendBufferFull is returned when a message cannot be sent to the
	// peer because the frame sending buffer has become full. Typically
	// this indicates that the connection is stuck and writes have become
	// backed up.
	ErrSendBufferFull = errors.New("connection send buffer is full, cannot send frame")

	// ErrConnectionNotReady is no longer used.
	ErrConnectionNotReady = errors.New("connection is not yet ready")

	errNoSyscallConn = errors.New("no syscall.RawConn available")
)

// errConnectionUnknownState is returned when the connection is in an unknown state.
type errConnectionUnknownState struct { site string state connectionState } func (e errConnectionUnknownState) Error() string { return fmt.Sprintf("connection is in unknown state: %v at %v", e.state, e.site) } // SendBufferSizeOverride is used for overriding per-process send buffer channel size for a // connection, using process name prefix matching. type SendBufferSizeOverride struct { ProcessNamePrefix string SendBufferSize int } // ConnectionOptions are options that control the behavior of a Connection type ConnectionOptions struct { // The frame pool, allowing better management of frame buffers. Defaults to using raw heap. FramePool FramePool // NOTE: This is deprecated and not used for anything. RecvBufferSize int // The size of send channel buffers. Defaults to 512. SendBufferSize int // Per-process name prefix override for SendBufferSize // Note that order matters, if there are multiple matches, the first one is used. SendBufferSizeOverrides []SendBufferSizeOverride // The type of checksum to use when sending messages. ChecksumType ChecksumType // ToS class name marked on outbound packets. TosPriority tos.ToS // HealthChecks configures active connection health checking for this channel. // By default, health checks are not enabled. HealthChecks HealthCheckOptions // MaxCloseTime controls how long we allow a connection to complete pending // calls before shutting down. Only used if it is non-zero. MaxCloseTime time.Duration // PropagateCancel enables cancel messages to cancel contexts. // By default, cancel messages are ignored. // This only affects inbounds (servers handling calls). PropagateCancel bool // SendCancelOnContextCanceled enables sending cancel messages // when a request context is canceled before receiving a response. // This only affects outbounds (clients making calls). SendCancelOnContextCanceled bool } // connectionEvents are the events that can be triggered by a connection. 
type connectionEvents struct { // OnActive is called when a connection becomes active. OnActive func(c *Connection) // OnCloseStateChange is called when a connection that is closing changes state. OnCloseStateChange func(c *Connection) // OnExchangeUpdated is called when a message exchange added or removed. OnExchangeUpdated func(c *Connection) } // Connection represents a connection to a remote peer. type Connection struct { channelConnectionCommon connID uint32 connDirection connectionDirection opts ConnectionOptions conn net.Conn sysConn syscall.RawConn // may be nil if conn cannot be converted localPeerInfo LocalPeerInfo remotePeerInfo PeerInfo sendCh chan *Frame stopCh chan struct{} state connectionState stateMut sync.RWMutex inbound *messageExchangeSet outbound *messageExchangeSet internalHandlers *handlerMap handler Handler nextMessageID atomic.Uint32 events connectionEvents commonStatsTags map[string]string relay *Relayer baseContext context.Context // outboundHP is the host:port we used to create this outbound connection. // It may not match remotePeerInfo.HostPort, in which case the connection is // added to peers for both host:ports. For inbound connections, this is empty. outboundHP string // closeNetworkCalled is used to avoid errors from being logged // when this side closes a connection. closeNetworkCalled atomic.Bool // stoppedExchanges is atomically set when exchanges are stopped due to error. stoppedExchanges atomic.Bool // remotePeerAddress is used as a cache for remote peer address parsed into individual // components that can be used to set peer tags on OpenTracing Span. remotePeerAddress peerAddressComponents // healthCheckCtx/Quit are used to stop health checks. healthCheckCtx context.Context healthCheckQuit context.CancelFunc healthCheckDone chan struct{} healthCheckHistory *healthHistory // lastActivity{Read,Write} is used to track how long the connection has been // idle for the recieve and send connections respectively. 
(unix time, nano) lastActivityRead atomic.Int64 lastActivityWrite atomic.Int64 } type peerAddressComponents struct { port uint16 ipv4 uint32 ipv6 string hostname string } // _nextConnID is used to allocate unique IDs to every connection for debugging purposes. var _nextConnID atomic.Uint32 type connectionState int const ( // Connection is fully active connectionActive connectionState = iota + 1 // Connection is starting to close; new incoming requests are rejected, outbound // requests are allowed to proceed connectionStartClose // Connection has finished processing all active inbound, and is // waiting for outbound requests to complete or timeout connectionInboundClosed // Connection is fully closed connectionClosed ) //go:generate stringer -type=connectionState func getTimeout(ctx context.Context) time.Duration { deadline, ok := ctx.Deadline() if !ok { return DefaultConnectTimeout } return deadline.Sub(time.Now()) } func (co ConnectionOptions) withDefaults() ConnectionOptions { if co.ChecksumType == ChecksumTypeNone { co.ChecksumType = ChecksumTypeCrc32 } if co.FramePool == nil { co.FramePool = DefaultFramePool } if co.SendBufferSize <= 0 { co.SendBufferSize = DefaultConnectionBufferSize } co.HealthChecks = co.HealthChecks.withDefaults() return co } func (co ConnectionOptions) getSendBufferSize(processName string) int { for _, override := range co.SendBufferSizeOverrides { if strings.HasPrefix(processName, override.ProcessNamePrefix) { return override.SendBufferSize } } return co.SendBufferSize } func (ch *Channel) setConnectionTosPriority(tosPriority tos.ToS, c net.Conn) error { tcpAddr, isTCP := c.RemoteAddr().(*net.TCPAddr) if !isTCP { return nil } // Handle dual stack listeners and set Traffic Class. 
var err error switch ip := tcpAddr.IP; { case ip.To16() != nil && ip.To4() == nil: err = ipv6.NewConn(c).SetTrafficClass(int(tosPriority)) case ip.To4() != nil: err = ipv4.NewConn(c).SetTOS(int(tosPriority)) } return err } func (ch *Channel) newConnection(baseCtx context.Context, conn net.Conn, initialID uint32, outboundHP string, remotePeer PeerInfo, remotePeerAddress peerAddressComponents, events connectionEvents) *Connection { opts := ch.connectionOptions.withDefaults() connID := _nextConnID.Inc() connDirection := inbound log := ch.log.WithFields(LogFields{ {"connID", connID}, {"localAddr", conn.LocalAddr().String()}, {"remoteAddr", conn.RemoteAddr().String()}, {"remoteHostPort", remotePeer.HostPort}, {"remoteIsEphemeral", remotePeer.IsEphemeral}, {"remoteProcess", remotePeer.ProcessName}, }...) if outboundHP != "" { connDirection = outbound log = log.WithFields(LogField{"outboundHP", outboundHP}) } log = log.WithFields(LogField{"connectionDirection", connDirection}) peerInfo := ch.PeerInfo() timeNow := ch.timeNow().UnixNano() c := &Connection{ channelConnectionCommon: ch.channelConnectionCommon, connID: connID, conn: conn, sysConn: getSysConn(conn, log), connDirection: connDirection, opts: opts, state: connectionActive, sendCh: make(chan *Frame, opts.getSendBufferSize(remotePeer.ProcessName)), stopCh: make(chan struct{}), localPeerInfo: peerInfo, remotePeerInfo: remotePeer, remotePeerAddress: remotePeerAddress, outboundHP: outboundHP, inbound: newMessageExchangeSet(log, messageExchangeSetInbound), outbound: newMessageExchangeSet(log, messageExchangeSetOutbound), internalHandlers: ch.internalHandlers, handler: ch.handler, events: events, commonStatsTags: ch.commonStatsTags, healthCheckHistory: newHealthHistory(), lastActivityRead: *atomic.NewInt64(timeNow), lastActivityWrite: *atomic.NewInt64(timeNow), baseContext: ch.connContext(baseCtx, conn), } if tosPriority := opts.TosPriority; tosPriority > 0 { if err := ch.setConnectionTosPriority(tosPriority, conn); err 
!= nil { log.WithFields(ErrField(err)).Error("Failed to set ToS priority.") } } c.nextMessageID.Store(initialID) c.log = log c.outbound.onCancel = c.onCancel c.inbound.onRemoved = c.checkExchanges c.outbound.onRemoved = c.checkExchanges c.inbound.onAdded = c.onExchangeAdded c.outbound.onAdded = c.onExchangeAdded if ch.RelayHost() != nil { c.relay = NewRelayer(ch, c) } // Connections are activated as soon as they are created. c.callOnActive() go c.readFrames(connID) go c.writeFrames(connID) return c } func (c *Connection) onCancel(msgID uint32) { if !c.opts.SendCancelOnContextCanceled { return } cancelMsg := &cancelMessage{ id: msgID, message: ErrRequestCancelled.Error(), } if err := c.sendMessage(cancelMsg); err != nil { c.connectionError("send cancel", err) } } func (c *Connection) onExchangeAdded() { c.callOnExchangeChange() } // IsActive returns whether this connection is in an active state. func (c *Connection) IsActive() bool { return c.readState() == connectionActive } func (c *Connection) callOnActive() { log := c.log if remoteVersion := c.remotePeerInfo.Version; remoteVersion != (PeerVersion{}) { log = log.WithFields(LogFields{ {"remotePeerLanguage", remoteVersion.Language}, {"remotePeerLanguageVersion", remoteVersion.LanguageVersion}, {"remotePeerTChannelVersion", remoteVersion.TChannelVersion}, }...) } log.Debug("Created new active connection.") if f := c.events.OnActive; f != nil { f(c) } if c.opts.HealthChecks.enabled() { c.healthCheckCtx, c.healthCheckQuit = context.WithCancel(context.Background()) c.healthCheckDone = make(chan struct{}) go c.healthCheck(c.connID) } } func (c *Connection) callOnCloseStateChange() { if f := c.events.OnCloseStateChange; f != nil { f(c) } } func (c *Connection) callOnExchangeChange() { if f := c.events.OnExchangeUpdated; f != nil { f(c) } } // ping sends a ping message and waits for a ping response. 
func (c *Connection) ping(ctx context.Context) error { req := &pingReq{id: c.NextMessageID()} mex, err := c.outbound.newExchange(ctx, c.outboundCtxCancel, c.opts.FramePool, req.messageType(), req.ID(), 1) if err != nil { return c.connectionError("create ping exchange", err) } defer c.outbound.removeExchange(req.ID()) if err := c.sendMessage(req); err != nil { return c.connectionError("send ping", err) } return c.recvMessage(ctx, &pingRes{}, mex) } // handlePingRes calls registered ping handlers. func (c *Connection) handlePingRes(frame *Frame) bool { if err := c.outbound.forwardPeerFrame(frame); err != nil { c.log.WithFields(LogField{"response", frame.Header}).Warn("Unexpected ping response.") return true } // ping req is waiting for this frame, and will release it. return false } // handlePingReq responds to the pingReq message with a pingRes. func (c *Connection) handlePingReq(frame *Frame) { if state := c.readState(); state != connectionActive { c.protocolError(frame.Header.ID, errConnNotActive{"ping on incoming", state}) return } pingRes := &pingRes{id: frame.Header.ID} if err := c.sendMessage(pingRes); err != nil { c.connectionError("send pong", err) } } // sendMessage sends a standalone message (typically a control message) func (c *Connection) sendMessage(msg message) error { frame := c.opts.FramePool.Get() if err := frame.write(msg); err != nil { c.opts.FramePool.Release(frame) return err } select { case c.sendCh <- frame: return nil default: return ErrSendBufferFull } } // recvMessage blocks waiting for a standalone response message (typically a // control message) func (c *Connection) recvMessage(ctx context.Context, msg message, mex *messageExchange) error { frame, err := mex.recvPeerFrameOfType(msg.messageType()) if err != nil { if err, ok := err.(errorMessage); ok { return err.AsSystemError() } return err } err = frame.read(msg) c.opts.FramePool.Release(frame) return err } // RemotePeerInfo returns the peer info for the remote peer. 
func (c *Connection) RemotePeerInfo() PeerInfo {
	return c.remotePeerInfo
}

// NextMessageID reserves the next available message id for this connection
func (c *Connection) NextMessageID() uint32 {
	return c.nextMessageID.Inc()
}

// SendSystemError sends an error frame for the given system error.
// The frame is queued without blocking; an error is returned if the frame
// could not be created, the connection is closed, or the send buffer is full.
func (c *Connection) SendSystemError(id uint32, span Span, err error) (sendErr error) {
	// Allocate an error frame to be sent over the connection. A nil is
	// returned if the frame was successfully sent, otherwise an error is
	// returned, and we must release the error frame back to the pool.
	frame := c.opts.FramePool.Get()
	defer func() {
		if sendErr != nil {
			c.opts.FramePool.Release(frame)
		}
	}()

	if err := frame.write(&errorMessage{
		id:      id,
		errCode: GetSystemErrorCode(err),
		tracing: span,
		message: GetSystemErrorMessage(err),
	}); err != nil {
		// This shouldn't happen - it means writing the errorMessage is broken.
		c.log.WithFields(
			LogField{"remotePeer", c.remotePeerInfo},
			LogField{"id", id},
			ErrField(err),
		).Warn("Couldn't create outbound frame.")
		return fmt.Errorf("failed to create outbound error frame: %v", err)
	}

	// When sending errors, we hold the state rlock to ensure that sendCh is not closed
	// as we are sending the frame.
	return c.withStateRLock(func() error {
		// Errors cannot be sent if the connection has been closed.
		if c.state == connectionClosed {
			c.log.WithFields(
				LogField{"remotePeer", c.remotePeerInfo},
				LogField{"id", id},
			).Info("Could not send error frame on closed connection.")
			return fmt.Errorf("failed to send error frame, connection state %v", c.state)
		}

		select {
		case c.sendCh <- frame: // Good to go
			return nil
		default: // If the send buffer is full, log and return an error.
		}
		// Note: the error logged here is the system error that was being
		// sent (the err parameter), not a separate send failure.
		c.log.WithFields(
			LogField{"remotePeer", c.remotePeerInfo},
			LogField{"id", id},
			ErrField(err),
		).Warn("Couldn't send outbound frame.")
		return fmt.Errorf("failed to send error frame, buffer full")
	})
}

// logConnectionError logs err at a level that depends on its kind (EOF at
// debug, non-network system errors at error, timeouts at warn, everything
// else at info) and returns it wrapped as a system error. The wrapped error
// uses ErrCodeNetwork unless err is a SystemError with a different code.
func (c *Connection) logConnectionError(site string, err error) error {
	errCode := ErrCodeNetwork
	if err == io.EOF {
		c.log.Debugf("Connection got EOF")
	} else {
		logger := c.log.WithFields(
			LogField{"site", site},
			ErrField(err),
		)
		if se, ok := err.(SystemError); ok && se.Code() != ErrCodeNetwork {
			errCode = se.Code()
			logger.Error("Connection error.")
		} else if ne, ok := err.(net.Error); ok && ne.Timeout() {
			logger.Warn("Connection error due to timeout.")
		} else {
			logger.Info("Connection error.")
		}
	}
	return NewWrappedSystemError(errCode, err)
}

// connectionError handles a connection level error
// It stops health checks, logs the error, starts closing the connection,
// fails all pending exchanges (at most once per connection), and returns the
// error wrapped as a system error.
func (c *Connection) connectionError(site string, err error) error {
	var closeLogFields LogFields
	if err == io.EOF {
		closeLogFields = LogFields{{"reason", "network connection EOF"}}
	} else {
		closeLogFields = LogFields{
			{"reason", "connection error"},
			ErrField(err),
		}
	}

	c.stopHealthCheck()
	err = c.logConnectionError(site, err)
	c.close(closeLogFields...)

	// On any connection error, notify the exchanges of this error.
	if c.stoppedExchanges.CAS(false, true) {
		c.outbound.stopExchanges(err)
		c.inbound.stopExchanges(err)
	}

	// checkExchanges will close the connection due to stoppedExchanges.
	c.checkExchanges()
	return err
}

// protocolError handles a protocol level error: it logs the error, attempts
// to send it to the peer as a system error for the given message id, starts
// closing the connection, and fails all pending exchanges (at most once per
// connection). The wrapped system error is returned.
func (c *Connection) protocolError(id uint32, err error) error {
	c.log.WithFields(ErrField(err)).Warn("Protocol error.")
	sysErr := NewWrappedSystemError(ErrCodeProtocol, err)
	// The result of sending is intentionally not checked here;
	// SendSystemError logs its own failures.
	c.SendSystemError(id, Span{}, sysErr)
	// Don't close the connection until the error has been sent.
	c.close(
		LogField{"reason", "protocol error"},
		ErrField(err),
	)

	// On any connection error, notify the exchanges of this error.
	if c.stoppedExchanges.CAS(false, true) {
		c.outbound.stopExchanges(sysErr)
		c.inbound.stopExchanges(sysErr)
	}
	return sysErr
}

// withStateLock performs an action with the connection state mutex locked
func (c *Connection) withStateLock(f func() error) error {
	c.stateMut.Lock()
	err := f()
	c.stateMut.Unlock()

	return err
}

// withStateRLock performs an action with the connection state mutex rlocked.
func (c *Connection) withStateRLock(f func() error) error {
	c.stateMut.RLock()
	err := f()
	c.stateMut.RUnlock()

	return err
}

// readState returns the current connection state, read under the state rlock.
func (c *Connection) readState() connectionState {
	c.stateMut.RLock()
	state := c.state
	c.stateMut.RUnlock()
	return state
}

// readFrames is the loop that reads frames from the network connection and
// dispatches to the appropriate handler. Run within its own goroutine to
// prevent overlapping reads on the socket. Most handlers simply send the
// incoming frame to a channel; the init handlers are a notable exception,
// since we cannot process new frames until the initialization is complete.
func (c *Connection) readFrames(_ uint32) {
	headerBuf := make([]byte, FrameHeaderSize)

	// Read errors seen after this side closed the network connection are
	// expected, so they are only logged at debug level.
	handleErr := func(err error) {
		if !c.closeNetworkCalled.Load() {
			c.connectionError("read frames", err)
		} else {
			c.log.Debugf("Ignoring error after connection was closed: %v", err)
		}
	}

	for {
		// Read the header, avoid allocating the frame till we know the size
		// we need to allocate.
		if _, err := io.ReadFull(c.conn, headerBuf); err != nil {
			handleErr(err)
			return
		}

		frame := c.opts.FramePool.Get()
		if err := frame.ReadBody(headerBuf, c.conn); err != nil {
			handleErr(err)
			c.opts.FramePool.Release(frame)
			return
		}

		c.updateLastActivityRead(frame)

		// Handlers report whether the frame can be released right away, or
		// whether something else has taken ownership of it.
		var releaseFrame bool
		if c.relay == nil {
			releaseFrame = c.handleFrameNoRelay(frame)
		} else {
			releaseFrame = c.handleFrameRelay(frame)
		}
		if releaseFrame {
			c.opts.FramePool.Release(frame)
		}
	}
}

// handleFrameRelay dispatches a frame on a connection that has a relay.
// Call, error and cancel frames are passed to the relay; all other frames are
// handled as on a non-relay connection. It returns whether the caller should
// release the frame.
func (c *Connection) handleFrameRelay(frame *Frame) bool {
	if frame.Header.messageType == messageTypeCancel && !c.opts.PropagateCancel {
		// If cancel propagation is disabled, don't do anything for this frame.
		if c.log.Enabled(LogLevelDebug) {
			c.log.Debugf("Ignoring cancel in relay for %v", frame.Header.ID)
		}
		return true
	}

	switch msgType := frame.Header.messageType; msgType {
	case messageTypeCallReq, messageTypeCallReqContinue, messageTypeCallRes, messageTypeCallResContinue, messageTypeError, messageTypeCancel:
		shouldRelease, err := c.relay.Relay(frame)
		if err != nil {
			c.log.WithFields(
				ErrField(err),
				LogField{"header", frame.Header},
				LogField{"remotePeer", c.remotePeerInfo},
			).Error("Failed to relay frame.")
		}
		return shouldRelease
	default:
		return c.handleFrameNoRelay(frame)
	}
}

// handleFrameNoRelay dispatches a frame to the handler for its message type
// and returns whether the caller should release the frame. Frames of an
// unexpected type are logged and released.
func (c *Connection) handleFrameNoRelay(frame *Frame) bool {
	releaseFrame := true

	// call req and call res messages may not want the frame released immediately.
	switch frame.Header.messageType {
	case messageTypeCallReq:
		releaseFrame = c.handleCallReq(frame)
	case messageTypeCallReqContinue:
		releaseFrame = c.handleCallReqContinue(frame)
	case messageTypeCallRes:
		releaseFrame = c.handleCallRes(frame)
	case messageTypeCallResContinue:
		releaseFrame = c.handleCallResContinue(frame)
	case messageTypePingReq:
		c.handlePingReq(frame)
	case messageTypePingRes:
		releaseFrame = c.handlePingRes(frame)
	case messageTypeError:
		releaseFrame = c.handleError(frame)
	case messageTypeCancel:
		releaseFrame = c.handleCancel(frame)
	default:
		// TODO(mmihic): Log and close connection with protocol error
		c.log.WithFields(
			LogField{"header", frame.Header},
			LogField{"remotePeer", c.remotePeerInfo},
		).Error("Received unexpected frame.")
	}

	return releaseFrame
}

// writeFrames is the main loop that pulls frames from the send channel and
// writes them to the connection.
func (c *Connection) writeFrames(_ uint32) {
	defer func() {
		// Wait for the connection to be stopped before draining, whichever
		// way the loop below exited.
		<-c.stopCh
		// Drain and release any remaining frames in sendCh for best-effort
		// reduction in leaked frames
		for len(c.sendCh) > 0 {
			c.opts.FramePool.Release(<-c.sendCh)
		}
	}()

	for {
		select {
		case f := <-c.sendCh:
			if c.log.Enabled(LogLevelDebug) {
				c.log.Debugf("Writing frame %s", f.Header)
			}

			c.updateLastActivityWrite(f)
			err := f.WriteOut(c.conn)
			c.opts.FramePool.Release(f)
			if err != nil {
				c.connectionError("write frames", err)
				return
			}
		case <-c.stopCh:
			// If there are frames in sendCh, we want to drain them.
			if len(c.sendCh) > 0 {
				continue
			}
			// Close the network once we're no longer writing frames.
			c.closeNetwork()
			return
		}
	}
}

// updateLastActivityRead marks when the last message was received on the channel.
// This is used for monitoring idle connections and timing them out.
func (c *Connection) updateLastActivityRead(frame *Frame) {
	if isMessageTypeCall(frame) {
		c.lastActivityRead.Store(c.timeNow().UnixNano())
	}
}

// updateLastActivityWrite marks when the last message was sent on the channel.
// This is used for monitoring idle connections and timing them out. func (c *Connection) updateLastActivityWrite(frame *Frame) { if isMessageTypeCall(frame) { c.lastActivityWrite.Store(c.timeNow().UnixNano()) } } // hasPendingCalls returns whether there's any pending inbound or outbound calls on this connection. func (c *Connection) hasPendingCalls() bool { if c.inbound.count() > 0 || c.outbound.count() > 0 { return true } if !c.relay.canClose() { return true } return false } // checkExchanges is called whenever an exchange is removed, and when Close is called. func (c *Connection) checkExchanges() { c.callOnExchangeChange() moveState := func(fromState, toState connectionState) bool { err := c.withStateLock(func() error { if c.state != fromState { return errors.New("") } c.state = toState return nil }) return err == nil } curState := c.readState() origState := curState if curState != connectionClosed && c.stoppedExchanges.Load() { if moveState(curState, connectionClosed) { curState = connectionClosed } } if curState == connectionStartClose { if !c.relay.canClose() { return } if c.inbound.count() == 0 && moveState(connectionStartClose, connectionInboundClosed) { curState = connectionInboundClosed } } if curState == connectionInboundClosed { // Safety check -- this should never happen since we already did the check // when transitioning to connectionInboundClosed. if !c.relay.canClose() { c.relay.logger.Error("Relay can't close even though state is InboundClosed.") return } if c.outbound.count() == 0 && moveState(connectionInboundClosed, connectionClosed) { curState = connectionClosed } } if curState != origState { // If the connection is closed, we can notify writeFrames to stop which // closes the underlying network connection. 
We never close sendCh to avoid // races causing panics, see 93ef5c112c8b321367ae52d2bd79396e2e874f31 if curState == connectionClosed { close(c.stopCh) } c.log.WithFields( LogField{"newState", curState}, ).Debug("Connection state updated during shutdown.") c.callOnCloseStateChange() } } func (c *Connection) close(fields ...LogField) error { c.log.WithFields(fields...).Debug("Connection closing.") // Update the state which will start blocking incoming calls. if err := c.withStateLock(func() error { switch s := c.state; s { case connectionActive: c.state = connectionStartClose default: return fmt.Errorf("connection must be Active to Close, but it is %v", s) } return nil }); err != nil { return err } // Set a read deadline with any close timeout. This will cause a i/o timeout // if the connection isn't closed by then. if c.opts.MaxCloseTime > 0 { c.conn.SetReadDeadline(c.timeNow().Add(c.opts.MaxCloseTime)) } c.log.WithFields( LogField{"newState", c.readState()}, ).Debug("Connection state updated in Close.") c.callOnCloseStateChange() // Check all in-flight requests to see whether we can transition the Close state. c.checkExchanges() return nil } // Close starts a graceful Close which will first reject incoming calls, reject outgoing calls // before finally marking the connection state as closed. func (c *Connection) Close() error { return c.close(LogField{"reason", "user initiated"}) } // closeNetwork closes the network connection and all network-related channels. // This should only be done in response to a fatal connection or protocol // error, or after all pending frames have been sent. 
func (c *Connection) closeNetwork() { // NB(mmihic): The sender goroutine will exit once the connection is // closed; no need to close the send channel (and closing the send // channel would be dangerous since other goroutine might be sending) c.log.Debugf("Closing underlying network connection") c.stopHealthCheck() c.closeNetworkCalled.Store(true) if err := c.conn.Close(); err != nil { c.log.WithFields( LogField{"remotePeer", c.remotePeerInfo}, ErrField(err), ).Warn("Couldn't close connection to peer.") } } // getLastActivityReadTime returns the timestamp of the last frame read, // excluding pings. If no frames were transmitted yet, it will return the time // this connection was created. func (c *Connection) getLastActivityReadTime() time.Time { return time.Unix(0, c.lastActivityRead.Load()) } // getLastActivityWriteTime returns the timestamp of the last frame written, // excluding pings. If no frames were transmitted yet, it will return the time // this connection was created. func (c *Connection) getLastActivityWriteTime() time.Time { return time.Unix(0, c.lastActivityWrite.Load()) } func getSysConn(conn net.Conn, log Logger) syscall.RawConn { var ( connSyscall syscall.Conn ok bool ) switch v := conn.(type) { case syscall.Conn: connSyscall = v ok = true case *tls.Conn: connSyscall, ok = v.NetConn().(syscall.Conn) } if !ok { log.WithFields(LogField{"connectionType", fmt.Sprintf("%T", conn)}). Error("Connection does not implement SyscallConn.") return nil } sysConn, err := connSyscall.SyscallConn() if err != nil { log.WithFields(ErrField(err)).Error("Could not get SyscallConn.") return nil } return sysConn } func isMessageTypeCall(frame *Frame) bool { // Pings are ignored for last activity. 
switch frame.Header.messageType { case messageTypeCallReq, messageTypeCallReqContinue, messageTypeCallRes, messageTypeCallResContinue, messageTypeError: return true } return false } ================================================ FILE: connection_bench_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "runtime" "sync" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/streadway/quantile" "github.com/stretchr/testify/assert" "golang.org/x/net/context" ) const benchService = "bench-server" type benchmarkHandler struct{} func (h *benchmarkHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ Arg2: args.Arg3, Arg3: args.Arg2, }, nil } func (h *benchmarkHandler) OnError(ctx context.Context, err error) { } type latencyTracker struct { sync.Mutex started time.Time estimator *quantile.Estimator } func newLatencyTracker() *latencyTracker { return &latencyTracker{ estimator: quantile.New( quantile.Unknown(0.01), quantile.Known(0.50, 0.01), quantile.Known(0.95, 0.001), quantile.Known(0.99, 0.0005), quantile.Known(1.0, 0.0005), ), started: time.Now(), } } func (lt *latencyTracker) addLatency(d time.Duration) { lt.Lock() lt.estimator.Add(float64(d)) lt.Unlock() } func (lt *latencyTracker) report(t testing.TB) { duration := time.Since(lt.started) lt.Lock() t.Logf("%6v calls, %5.0f RPS (%v per call). 
Latency: Average: %v P95: %v P99: %v P100: %v", lt.estimator.Samples(), float64(lt.estimator.Samples())/float64(duration)*float64(time.Second), duration/time.Duration(lt.estimator.Samples()), time.Duration(lt.estimator.Get(0.50)), time.Duration(lt.estimator.Get(0.95)), time.Duration(lt.estimator.Get(0.99)), time.Duration(lt.estimator.Get(1.0)), ) lt.Unlock() } func setupServer(t testing.TB) *Channel { serverCh := testutils.NewServer(t, testutils.NewOpts().SetServiceName("bench-server")) handler := &benchmarkHandler{} serverCh.Register(raw.Wrap(handler), "echo") return serverCh } type benchmarkConfig struct { numCalls int numServers int numClients int workersPerClient int numBytes int } func benchmarkCallsN(b *testing.B, c benchmarkConfig) { var ( clients []*Channel servers []*Channel ) lt := newLatencyTracker() if c.numBytes == 0 { c.numBytes = 100 } data := testutils.RandBytes(c.numBytes) // Set up clients and servers. for i := 0; i < c.numServers; i++ { servers = append(servers, setupServer(b)) } for i := 0; i < c.numClients; i++ { clients = append(clients, testutils.NewClient(b, nil)) for _, s := range servers { clients[i].Peers().Add(s.PeerInfo().HostPort) // Initialize a connection ctx, cancel := NewContext(50 * time.Millisecond) assert.NoError(b, clients[i].Ping(ctx, s.PeerInfo().HostPort), "Initial ping failed") cancel() } } // Make calls from clients to the servers call := func(sc *SubChannel) { ctx, cancel := NewContext(50 * time.Millisecond) start := time.Now() _, _, _, err := raw.CallSC(ctx, sc, "echo", nil, data) duration := time.Since(start) cancel() if assert.NoError(b, err, "Call failed") { lt.addLatency(duration) } } reqsLeft := testutils.Decrementor(c.numCalls) clientWorker := func(client *Channel, clientNum, workerNum int) { sc := client.GetSubChannel(benchService) for reqsLeft.Single() { call(sc) } } clientRunner := func(client *Channel, clientNum int) { testutils.RunN(c.workersPerClient, func(i int) { clientWorker(client, clientNum, i) }) } lt = 
newLatencyTracker() defer lt.report(b) b.ResetTimer() testutils.RunN(c.numClients, func(i int) { clientRunner(clients[i], i) }) } func BenchmarkCallsSerial(b *testing.B) { benchmarkCallsN(b, benchmarkConfig{ numCalls: b.N, numServers: 1, numClients: 1, workersPerClient: 1, }) } func BenchmarkCallsConcurrentServer(b *testing.B) { benchmarkCallsN(b, benchmarkConfig{ numCalls: b.N, numServers: 1, numClients: runtime.GOMAXPROCS(0), workersPerClient: 1, }) } func BenchmarkCallsConcurrentClient(b *testing.B) { parallelism := runtime.GOMAXPROCS(0) benchmarkCallsN(b, benchmarkConfig{ numCalls: b.N, numServers: parallelism, numClients: 1, workersPerClient: parallelism, }) } ================================================ FILE: connection_direction.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import "fmt"

// connectionDirection labels a connection as inbound or outbound.
type connectionDirection int

// Valid connectionDirection values. Numbering starts at 1, so the zero value
// of connectionDirection matches neither constant.
const (
	inbound connectionDirection = iota + 1
	outbound
)

// String returns "inbound" or "outbound" for the two declared constants. Any
// other value is rendered as "connectionDirection(N)", where N is the
// underlying integer.
func (d connectionDirection) String() string {
	if d == inbound {
		return "inbound"
	}
	if d == outbound {
		return "outbound"
	}
	return fmt.Sprintf("connectionDirection(%v)", int(d))
}

================================================
FILE: connection_internal_test.go
================================================
// Copyright (c) 2020 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package tchannel

import (
	"bytes"
	"crypto/tls"
	"net"
	"net/http/httptest"
	"syscall"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// errSyscallConn wraps a net.Conn and adds a SyscallConn method that always
// fails, which lets the test drive the error branch of getSysConn.
type errSyscallConn struct {
	net.Conn
}

// SyscallConn always returns a nil RawConn together with assert.AnError.
func (c errSyscallConn) SyscallConn() (syscall.RawConn, error) {
	return nil, assert.AnError
}

// TestGetSysConn checks both the value returned by getSysConn and what it
// writes to the logger, for connections with and without a usable
// SyscallConn method.
func TestGetSysConn(t *testing.T) {
	// A connection type that has no SyscallConn method at all.
	t.Run("no SyscallConn", func(t *testing.T) {
		// The local type name is asserted on below, so it must stay "dummyConn".
		type dummyConn struct {
			net.Conn
		}

		var (
			logs    = &bytes.Buffer{}
			lg      = NewLogger(logs)
			rawConn = getSysConn(dummyConn{}, lg)
		)
		require.Nil(t, rawConn, "expected no syscall.RawConn to be returned")

		out := logs.String()
		assert.Contains(t, out, "Connection does not implement SyscallConn", "missing log")
		assert.Contains(t, out, "dummyConn", "missing type in log")
	})

	// A connection whose SyscallConn method reports an error.
	t.Run("SyscallConn returns error", func(t *testing.T) {
		var (
			logs    = &bytes.Buffer{}
			lg      = NewLogger(logs)
			rawConn = getSysConn(errSyscallConn{}, lg)
		)
		require.Nil(t, rawConn, "expected no syscall.RawConn to be returned")

		out := logs.String()
		assert.Contains(t, out, "Could not get SyscallConn", "missing log")
		assert.Contains(t, out, assert.AnError.Error(), "missing error in log")
	})

	// A real TCP connection dialed against a local listener.
	t.Run("SyscallConn is successful", func(t *testing.T) {
		var (
			logs = &bytes.Buffer{}
			lg   = NewLogger(logs)
		)

		listener, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err, "Failed to listen")
		defer listener.Close()

		client, err := net.Dial("tcp", listener.Addr().String())
		require.NoError(t, err, "failed to dial")
		defer client.Close()

		rawConn := getSysConn(client, lg)
		require.NotNil(t, rawConn)
		assert.Empty(t, logs.String(), "expected no logs on success")
	})

	// A TLS client connection dialed against an httptest TLS server.
	t.Run("SyscallConn is successful with TLS", func(t *testing.T) {
		logs := &bytes.Buffer{}
		lg := NewLogger(logs)
		srv := httptest.NewTLSServer(nil)
		defer srv.Close()

		client, err := tls.Dial("tcp", srv.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true})
		require.NoError(t, err, "failed to dial")
		defer client.Close()

		rawConn := getSysConn(client, lg)
		require.NotNil(t, rawConn)
		assert.Empty(t, logs.String(), "expected no logs on success")
	})

	// A nil connection: nothing is returned and the logged type is empty.
	t.Run("no SyscallConn - nil net.Conn", func(t *testing.T) {
		logs := &bytes.Buffer{}
		lg := NewLogger(logs)
		rawConn := getSysConn(nil /* conn */, lg)

		require.Nil(t, rawConn, "expected no syscall.RawConn to be returned")

		out := logs.String()
		assert.Contains(t, out, "Connection does not implement SyscallConn", "missing log")
		assert.Contains(t, out, "{connectionType }", "missing type in log")
	})

	// A zero-value *tls.Conn, which has no underlying net.Conn.
	t.Run("no SyscallConn - TLS with no net.Conn", func(t *testing.T) {
		logs := &bytes.Buffer{}
		lg := NewLogger(logs)
		rawConn := getSysConn(&tls.Conn{}, lg)

		require.Nil(t, rawConn, "expected no syscall.RawConn to be returned")

		out := logs.String()
		assert.Contains(t, out, "Connection does not implement SyscallConn", "missing log")
		assert.Contains(t, out, "{connectionType *tls.Conn}", "missing type in log")
	})
}

================================================
FILE: connection_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "crypto/tls" "errors" "fmt" "io" "math" "net" "os" "runtime" "strings" "sync" "testing" "time" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/relay/relaytest" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/testutils/testreader" "github.com/uber/tchannel-go/tos" "github.com/uber/tchannel-go/typed" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) // Values used in tests const ( inbound = 0 outbound = 1 ) var ( testArg2 = []byte("Header in arg2") testArg3 = []byte("Body in arg3") ) type testHandler struct { sync.Mutex t testing.TB format Format caller string blockErr chan error } func newTestHandler(t testing.TB) *testHandler { return &testHandler{t: t, blockErr: make(chan error, 1)} } func (h *testHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { h.Lock() h.format = args.Format h.caller = args.Caller h.Unlock() assert.Equal(h.t, args.Caller, CurrentCall(ctx).CallerName()) switch args.Method { case "block": <-ctx.Done() h.blockErr <- ctx.Err() return &raw.Res{ IsErr: true, }, nil case "echo": return &raw.Res{ Arg2: args.Arg2, Arg3: args.Arg3, }, nil case "busy": return &raw.Res{ SystemErr: ErrServerBusy, }, nil case "app-error": return &raw.Res{ IsErr: true, }, nil } return nil, errors.New("unknown method") } func (h *testHandler) OnError(ctx context.Context, err error) { stack := make([]byte, 4096) runtime.Stack(stack, false /* all */) h.t.Errorf("testHandler got error: %v stack:\n%s", err, stack) } func writeFlushStr(w ArgWriter, d string) error { if _, err := io.WriteString(w, 
d); err != nil { return err } return w.Flush() } func isTosPriority(c net.Conn, tosPriority tos.ToS) (bool, error) { var connTosPriority int var err error switch ip := c.RemoteAddr().(*net.TCPAddr).IP; { case ip.To16() != nil && ip.To4() == nil: connTosPriority, err = ipv6.NewConn(c).TrafficClass() case ip.To4() != nil: connTosPriority, err = ipv4.NewConn(c).TOS() } return connTosPriority == int(tosPriority), err } func getErrorFrame(t testing.TB) *Frame { var errFrame *Frame server := testutils.NewServer(t, testutils.NewOpts().DisableLogVerification()) defer server.Close() frameRelay, cancel := testutils.FrameRelay(t, server.PeerInfo().HostPort, func(outgoing bool, f *Frame) *Frame { if strings.Contains(f.Header.String(), "Error") { errFrame = f } return f }) defer cancel() testutils.CallEcho(server, frameRelay, "unknown", nil) require.NotNil(t, errFrame, "Failed to get error frame") return errFrame } func TestRoundTrip(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { handler := newTestHandler(t) ts.Register(raw.Wrap(handler), "echo") ctx, cancel := NewContext(time.Second) defer cancel() call, err := ts.Server().BeginCall(ctx, ts.HostPort(), ts.ServiceName(), "echo", &CallOptions{Format: JSON}) require.NoError(t, err) assert.NotEmpty(t, call.RemotePeer().HostPort) assert.Equal(t, ts.Server().PeerInfo(), call.LocalPeer(), "Unexpected local peer") require.NoError(t, NewArgWriter(call.Arg2Writer()).Write(testArg2)) require.NoError(t, NewArgWriter(call.Arg3Writer()).Write(testArg3)) var respArg2 []byte require.NoError(t, NewArgReader(call.Response().Arg2Reader()).Read(&respArg2)) assert.Equal(t, testArg2, []byte(respArg2)) var respArg3 []byte require.NoError(t, NewArgReader(call.Response().Arg3Reader()).Read(&respArg3)) assert.Equal(t, testArg3, []byte(respArg3)) assert.Equal(t, JSON, handler.format) assert.Equal(t, ts.ServiceName(), handler.caller) assert.Equal(t, JSON, call.Response().Format(), "response Format should 
match request Format") }) } func TestDefaultFormat(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { handler := newTestHandler(t) ts.Register(raw.Wrap(handler), "echo") ctx, cancel := NewContext(time.Second) defer cancel() arg2, arg3, resp, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "echo", testArg2, testArg3) require.Nil(t, err) require.Equal(t, testArg2, arg2) require.Equal(t, testArg3, arg3) require.Equal(t, Raw, handler.format) assert.Equal(t, Raw, resp.Format(), "response Format should match request Format") }) } func TestRemotePeer(t *testing.T) { wantVersion := PeerVersion{ Language: "go", LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), TChannelVersion: VersionInfo, } tests := []struct { name string remote func(testing.TB, *testutils.TestServer) *Channel expectedFn func(*RuntimeState, *testutils.TestServer) PeerInfo }{ { name: "ephemeral client", remote: func(t testing.TB, ts *testutils.TestServer) *Channel { return ts.NewClient(nil) }, expectedFn: func(state *RuntimeState, ts *testutils.TestServer) PeerInfo { return PeerInfo{ HostPort: state.RootPeers[ts.HostPort()].OutboundConnections[0].LocalHostPort, IsEphemeral: true, ProcessName: state.LocalPeer.ProcessName, Version: wantVersion, } }, }, { name: "listening server", remote: func(t testing.TB, ts *testutils.TestServer) *Channel { return ts.NewServer(nil) }, expectedFn: func(state *RuntimeState, ts *testutils.TestServer) PeerInfo { return PeerInfo{ HostPort: state.LocalPeer.HostPort, IsEphemeral: false, ProcessName: state.LocalPeer.ProcessName, Version: wantVersion, } }, }, } ctx, cancel := NewContext(time.Second) defer cancel() for _, tt := range tests { opts := testutils.NewOpts().SetServiceName("fake-service").NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { remote := tt.remote(t, ts) defer remote.Close() gotPeer := make(chan PeerInfo, 1) ts.RegisterFunc("test", func(ctx 
context.Context, args *raw.Args) (*raw.Res, error) { gotPeer <- CurrentCall(ctx).RemotePeer() assert.Equal(t, ts.Server().PeerInfo(), CurrentCall(ctx).LocalPeer()) return &raw.Res{}, nil }) _, _, _, err := raw.Call(ctx, remote, ts.HostPort(), ts.Server().ServiceName(), "test", nil, nil) assert.NoError(t, err, "%v: Call failed", tt.name) expected := tt.expectedFn(remote.IntrospectState(nil), ts) assert.Equal(t, expected, <-gotPeer, "%v: RemotePeer mismatch", tt.name) }) } } func TestReuseConnection(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() // Since we're specifically testing that connections between hosts are re-used, // we can't interpose a relay in this test. s1Opts := testutils.NewOpts().SetServiceName("s1").NoRelay() testutils.WithTestServer(t, s1Opts, func(t testing.TB, ts *testutils.TestServer) { ch2 := ts.NewServer(&testutils.ChannelOpts{ServiceName: "s2"}) hostPort2 := ch2.PeerInfo().HostPort defer ch2.Close() ts.Register(raw.Wrap(newTestHandler(t)), "echo") ch2.Register(raw.Wrap(newTestHandler(t)), "echo") outbound, err := ts.Server().BeginCall(ctx, hostPort2, "s2", "echo", nil) require.NoError(t, err) outboundConn, outboundNetConn := OutboundConnection(outbound) // Try to make another call at the same time, should reuse the same connection. outbound2, err := ts.Server().BeginCall(ctx, hostPort2, "s2", "echo", nil) require.NoError(t, err) outbound2Conn, _ := OutboundConnection(outbound) assert.Equal(t, outboundConn, outbound2Conn) // Wait for the connection to be marked as active in ch2. assert.True(t, testutils.WaitFor(time.Second, func() bool { return ch2.IntrospectState(nil).NumConnections > 0 }), "ch2 does not have any active connections") // When ch2 tries to call the test server, it should reuse the existing // inbound connection the test server. Of course, this only works if the // test server -> ch2 call wasn't relayed. 
outbound3, err := ch2.BeginCall(ctx, ts.HostPort(), "s1", "echo", nil) require.NoError(t, err) _, outbound3NetConn := OutboundConnection(outbound3) assert.Equal(t, outboundNetConn.RemoteAddr(), outbound3NetConn.LocalAddr()) assert.Equal(t, outboundNetConn.LocalAddr(), outbound3NetConn.RemoteAddr()) // Ensure all calls can complete in parallel. var wg sync.WaitGroup for _, call := range []*OutboundCall{outbound, outbound2, outbound3} { wg.Add(1) go func(call *OutboundCall) { defer wg.Done() resp1, resp2, _, err := raw.WriteArgs(call, []byte("arg2"), []byte("arg3")) require.NoError(t, err) assert.Equal(t, resp1, []byte("arg2"), "result does match argument") assert.Equal(t, resp2, []byte("arg3"), "result does match argument") }(call) } wg.Wait() }) } func TestPing(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(time.Second) defer cancel() errFrame := getErrorFrame(t) var returnErr bool frameRelay, close := testutils.FrameRelay(t, ts.HostPort(), func(outgoing bool, f *Frame) *Frame { if !outgoing && returnErr { errFrame.Header.ID = f.Header.ID f = errFrame } return f }) defer close() clientCh := ts.NewClient(nil) defer clientCh.Close() require.NoError(t, clientCh.Ping(ctx, frameRelay)) conn, err := clientCh.RootPeers().GetOrAdd(frameRelay).GetConnection(ctx) require.NoError(t, err, "Failed to get connection") returnErr = true require.Error(t, conn.Ping(ctx), "Expect error from error frame") require.True(t, conn.IsActive(), "Connection should still be active after error frame") returnErr = false require.NoError(t, conn.Ping(ctx), "Ping should succeed") }) } func TestBadRequest(t *testing.T) { // ch will log an error when it receives a request for an unknown handler. 
opts := testutils.NewOpts().AddLogFilter("Couldn't find handler.", 1) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(time.Second) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "Noone", []byte("Headers"), []byte("Body")) require.NotNil(t, err) assert.Equal(t, ErrCodeBadRequest, GetSystemErrorCode(err)) calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "Noone").Failed("bad-request").End() ts.AssertRelayStats(calls) }) } func TestNoTimeout(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(raw.Wrap(newTestHandler(t)), "Echo") ctx := context.Background() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), "svc", "Echo", []byte("Headers"), []byte("Body")) assert.Equal(t, ErrTimeoutRequired, err) ts.AssertRelayStats(relaytest.NewMockStats()) }) } func TestCancelled(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(raw.Wrap(newTestHandler(t)), "echo") ctx, cancel := NewContext(time.Second) // Make a call first to make sure we have a connection. // We want to test the BeginCall path. _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "echo", []byte("Headers"), []byte("Body")) assert.NoError(t, err, "Call failed") // Now cancel the context. 
cancel() _, _, _, err = raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "echo", []byte("Headers"), []byte("Body")) assert.Equal(t, ErrRequestCancelled, err, "Unexpected error when making call with canceled context") }) } func TestNoServiceNaming(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(time.Second) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), "", "Echo", []byte("Headers"), []byte("Body")) assert.Equal(t, ErrNoServiceName, err) ts.AssertRelayStats(relaytest.NewMockStats()) }) } func TestServerBusy(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(ErrorHandlerFunc(func(ctx context.Context, call *InboundCall) error { if _, err := raw.ReadArgs(call); err != nil { return err } return ErrServerBusy }), "busy") ctx, cancel := NewContext(time.Second) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "busy", []byte("Arg2"), []byte("Arg3")) require.NotNil(t, err) assert.Equal(t, ErrCodeBusy, GetSystemErrorCode(err), "err: %v", err) calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "busy").Failed("busy").End() ts.AssertRelayStats(calls) }) } func TestUnexpectedHandlerError(t *testing.T) { opts := testutils.NewOpts(). 
AddLogFilter("Unexpected handler error", 1) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ts.Register(ErrorHandlerFunc(func(ctx context.Context, call *InboundCall) error { if _, err := raw.ReadArgs(call); err != nil { return err } return fmt.Errorf("nope") }), "nope") ctx, cancel := NewContext(time.Second) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "nope", []byte("Arg2"), []byte("Arg3")) require.NotNil(t, err) assert.Equal(t, ErrCodeUnexpected, GetSystemErrorCode(err), "err: %v", err) calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "nope").Failed("unexpected-error").End() ts.AssertRelayStats(calls) }) } type onErrorTestHandler struct { *testHandler onError func(ctx context.Context, err error) } func (h onErrorTestHandler) OnError(ctx context.Context, err error) { h.onError(ctx, err) } func TestTimeout(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { // onError may be called when the block call tries to write the call response. onError := func(ctx context.Context, err error) { assert.Equal(t, ErrTimeout, err, "onError err should be ErrTimeout") assert.Equal(t, context.DeadlineExceeded, ctx.Err(), "Context should timeout") } testHandler := onErrorTestHandler{newTestHandler(t), onError} ts.Register(raw.Wrap(testHandler), "block") ctx, cancel := NewContext(testutils.Timeout(100 * time.Millisecond)) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "block", []byte("Arg2"), []byte("Arg3")) assert.Equal(t, ErrTimeout, err) // Verify the server-side receives an error from the context. 
select { case err := <-testHandler.blockErr: assert.Equal(t, context.DeadlineExceeded, err, "Server should have received timeout") case <-time.After(time.Second): t.Errorf("Server did not receive call, may need higher timeout") } calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "block").Failed("timeout").End() ts.AssertRelayStats(calls) }) } func TestServerClientCancellation(t *testing.T) { opts := testutils.NewOpts() opts.DefaultConnectionOptions.SendCancelOnContextCanceled = true opts.DefaultConnectionOptions.PropagateCancel = true serverStats := newRecordingStatsReporter() opts.StatsReporter = serverStats testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { callReceived := make(chan struct{}) testutils.RegisterFunc(ts.Server(), "ctxWait", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { require.NoError(t, ctx.Err(), "context valid before cancellation") close(callReceived) <-ctx.Done() assert.Equal(t, context.Canceled, ctx.Err()) return nil, assert.AnError }) ctx, cancel := NewContext(3 * time.Second) defer cancel() // Wait for the call to be recieved by the server before cancelling the context. 
go func() { <-callReceived cancel() }() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "ctxWait", nil, nil) assert.Equal(t, ErrRequestCancelled, err, "client call result") statsTags := ts.Server().StatsTags() serverStats.Expected.IncCounter("inbound.cancels.requested", statsTags, 1) serverStats.Expected.IncCounter("inbound.cancels.honored", statsTags, 1) calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "ctxWait").Failed("canceled").End() ts.AssertRelayStats(calls) }) serverStats.ValidateExpected(t) } func TestCancelWithoutSendCancelOnContextCanceled(t *testing.T) { tests := []struct { msg string sendCancelOnContextCanceled bool wantCancelRequested bool }{ { msg: "no send or process cancel", sendCancelOnContextCanceled: false, }, { msg: "only enable cancels on outbounds", sendCancelOnContextCanceled: true, wantCancelRequested: true, }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { opts := testutils.NewOpts() opts.DefaultConnectionOptions.SendCancelOnContextCanceled = tt.sendCancelOnContextCanceled serverStats := newRecordingStatsReporter() opts.StatsReporter = serverStats testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { serverStats.Reset() callReceived := make(chan struct{}) testutils.RegisterFunc(ts.Server(), "ctxWait", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { require.NoError(t, ctx.Err(), "context valid before cancellation") close(callReceived) <-ctx.Done() assert.Equal(t, context.DeadlineExceeded, ctx.Err()) return nil, assert.AnError }) ctx, cancel := NewContext(testutils.Timeout(300 * time.Millisecond)) defer cancel() // Wait for the call to be recieved by the server before cancelling the context. 
go func() { <-callReceived cancel() }() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "ctxWait", nil, nil) assert.Equal(t, ErrRequestCancelled, err, "client call result") calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "ctxWait").Failed("timeout").End() ts.AssertRelayStats(calls) ts.AddPostFn(func() { // Validating these at the end of the test, when server has fully processed the cancellation. if tt.wantCancelRequested && !ts.HasRelay() { serverStats.Expected.IncCounter("inbound.cancels.requested", ts.Server().StatsTags(), 1) serverStats.ValidateExpected(t) } else { serverStats.EnsureNotPresent(t, "inbound.cancels.requested") } serverStats.EnsureNotPresent(t, "inbound.cancels.honored") }) }) }) } } func TestLargeMethod(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(time.Second) defer cancel() largeMethod := testutils.RandBytes(16*1024 + 1) _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), string(largeMethod), nil, nil) assert.Equal(t, ErrMethodTooLarge, err) }) } func TestLargeTimeout(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(raw.Wrap(newTestHandler(t)), "echo") ctx, cancel := NewContext(1000 * time.Second) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "echo", testArg2, testArg3) assert.NoError(t, err, "Call failed") calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "echo").Succeeded().End() ts.AssertRelayStats(calls) }) } func TestFragmentation(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(raw.Wrap(newTestHandler(t)), "echo") arg2 := make([]byte, MaxFramePayloadSize*2) for i := 0; i < len(arg2); i++ { arg2[i] = byte('a' + (i % 10)) } arg3 := make([]byte, MaxFramePayloadSize*3) for i := 0; i < len(arg3); 
i++ { arg3[i] = byte('A' + (i % 10)) } ctx, cancel := NewContext(time.Second) defer cancel() respArg2, respArg3, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "echo", arg2, arg3) require.NoError(t, err) assert.Equal(t, arg2, respArg2) assert.Equal(t, arg3, respArg3) calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "echo").Succeeded().End() ts.AssertRelayStats(calls) }) } func TestFragmentationSlowReader(t *testing.T) { // Inbound forward will timeout and cause a warning log. opts := testutils.NewOpts(). AddLogFilter("Unable to forward frame", 1). AddLogFilter("Connection error", 1) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { startReading, handlerComplete := make(chan struct{}), make(chan struct{}) handler := func(ctx context.Context, call *InboundCall) { <-startReading <-ctx.Done() _, err := raw.ReadArgs(call) assert.Error(t, err, "ReadArgs should fail since frames will be dropped due to slow reading") close(handlerComplete) } ts.Register(HandlerFunc(handler), "echo") arg2 := testutils.RandBytes(MaxFramePayloadSize * MexChannelBufferSize) arg3 := testutils.RandBytes(MaxFramePayloadSize * (MexChannelBufferSize + 1)) ctx, cancel := NewContext(testutils.Timeout(30 * time.Millisecond)) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "echo", arg2, arg3) assert.Error(t, err, "Call should timeout due to slow reader") close(startReading) select { case <-handlerComplete: case <-time.After(testutils.Timeout(70 * time.Millisecond)): t.Errorf("Handler not called, context timeout may be too low") } calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "echo").Failed("timeout").End() ts.AssertRelayStats(calls) }) } func TestWriteArg3AfterTimeout(t *testing.T) { // TODO: Debug why this is flaky in github if os.Getenv("GITHUB_WORKFLOW") != "" { t.Skip("skipping test flaky in github actions.") } // The channel reads and 
writes during timeouts, causing warning logs. opts := testutils.NewOpts().DisableLogVerification() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { timedOut := make(chan struct{}) handler := func(ctx context.Context, call *InboundCall) { _, err := raw.ReadArgs(call) assert.NoError(t, err, "Read args failed") response := call.Response() assert.NoError(t, NewArgWriter(response.Arg2Writer()).Write(nil), "Write Arg2 failed") writer, err := response.Arg3Writer() assert.NoError(t, err, "Arg3Writer failed") for { if _, err := writer.Write(testutils.RandBytes(4)); err != nil { assert.Equal(t, err, ErrTimeout, "Handler should timeout") close(timedOut) return } runtime.Gosched() } } ts.Register(HandlerFunc(handler), "call") ctx, cancel := NewContext(testutils.Timeout(100 * time.Millisecond)) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "call", nil, nil) assert.Equal(t, err, ErrTimeout, "Call should timeout") // Wait for the write to complete, make sure there are no errors. select { case <-time.After(testutils.Timeout(300 * time.Millisecond)): t.Errorf("Handler should have failed due to timeout") case <-timedOut: } calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "call").Failed("timeout").Succeeded().End() ts.AssertRelayStats(calls) }) } func TestLargeSendSystemError(t *testing.T) { largeStr := strings.Repeat("0123456789", 10000) tests := []struct { msg string err error wantErr string }{ { msg: "error message too long", err: errors.New(largeStr), wantErr: "too long", }, { msg: "max allowed error message", err: errors.New(largeStr[:math.MaxUint16-1]), wantErr: typed.ErrBufferFull.Error(), // error message is within length, but it overflows the frame. 
}, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() opts := testutils.NewOpts().AddLogFilter("Couldn't create outbound frame.", 1) client := ts.NewClient(opts) conn, err := client.Connect(ctx, ts.HostPort()) require.NoError(t, err, "Connect failed") err = conn.SendSystemError(1, Span{}, tt.err) require.Error(t, err, "Expect err") assert.Contains(t, err.Error(), tt.wantErr, "unexpected error") }) }) } } func TestWriteErrorAfterTimeout(t *testing.T) { // TODO: Make this test block at different points (e.g. before, during read/write). testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { timedOut := make(chan struct{}) done := make(chan struct{}) handler := func(ctx context.Context, call *InboundCall) { <-ctx.Done() <-timedOut _, err := raw.ReadArgs(call) assert.Equal(t, ErrTimeout, err, "Read args should fail with timeout") response := call.Response() assert.Equal(t, ErrTimeout, response.SendSystemError(ErrServerBusy), "SendSystemError should fail") close(done) } ts.Register(HandlerFunc(handler), "call") ctx, cancel := NewContext(testutils.Timeout(30 * time.Millisecond)) defer cancel() _, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "call", nil, testutils.RandBytes(100000)) assert.Equal(t, err, ErrTimeout, "Call should timeout") close(timedOut) select { case <-done: case <-time.After(time.Second): t.Errorf("Handler not called, timeout may be too low") } calls := relaytest.NewMockStats() calls.Add(ts.ServiceName(), ts.ServiceName(), "call").Failed("timeout").End() ts.AssertRelayStats(calls) }) } func TestWriteAfterConnectionError(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() // Closing network connections can lead to warnings in many places. 
// TODO: Relay is disabled due to https://github.com/uber/tchannel-go/issues/390 // Enabling relay causes the test to be flaky. opts := testutils.NewOpts().DisableLogVerification().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { testutils.RegisterEcho(ts.Server(), nil) server := ts.Server() call, err := server.BeginCall(ctx, ts.HostPort(), server.ServiceName(), "echo", nil) require.NoError(t, err, "Call failed") w, err := call.Arg2Writer() require.NoError(t, err, "Arg2Writer failed") require.NoError(t, writeFlushStr(w, "initial"), "write initial failed") // Now close the underlying network connection, writes should fail. _, conn := OutboundConnection(call) conn.Close() // Writes should start failing pretty soon. var writeErr error for i := 0; i < 100; i++ { if writeErr = writeFlushStr(w, "f"); writeErr != nil { break } time.Sleep(time.Millisecond) } if assert.Error(t, writeErr, "Writes should fail after a connection is closed") { assert.Equal(t, ErrCodeNetwork, GetSystemErrorCode(writeErr), "write should fail due to network error") } }) } func TestReadTimeout(t *testing.T) { // The error frame may fail to send since the connection closes before the handler sends it // or the handler connection may be closed as it sends when the other side closes the conn. opts := testutils.NewOpts(). AddLogFilter("Couldn't send outbound error frame", 1). AddLogFilter("Connection error", 1, "site", "read frames"). AddLogFilter("Connection error", 1, "site", "write frames"). 
AddLogFilter("simpleHandler OnError", 1, "error", "failed to send error frame, connection state connectionClosed")
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		sn := ts.ServiceName()
		calls := relaytest.NewMockStats()

		for i := 0; i < 10; i++ {
			ctx, cancel := NewContext(time.Second)
			handler := func(ctx context.Context, args *raw.Args) (*raw.Res, error) {
				defer cancel()
				return nil, ErrRequestCancelled
			}
			ts.RegisterFunc("call", handler)

			_, _, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.ServiceName(), "call", nil, nil)
			assert.Equal(t, err, ErrRequestCancelled, "Call should fail due to cancel")
			calls.Add(sn, sn, "call").Failed("cancelled").End()
		}

		ts.AssertRelayStats(calls)
	})
}

// TestWriteTimeout verifies that writing to an outbound call after its
// context has expired fails with ErrTimeout.
func TestWriteTimeout(t *testing.T) {
	testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) {
		ch := ts.Server()
		ctx, cancel := NewContext(testutils.Timeout(100 * time.Millisecond))
		defer cancel()

		call, err := ch.BeginCall(ctx, ts.HostPort(), ch.ServiceName(), "call", nil)
		require.NoError(t, err, "Call failed")

		writer, err := call.Arg2Writer()
		require.NoError(t, err, "Arg2Writer failed")

		_, err = writer.Write([]byte{1})
		require.NoError(t, err, "Write initial bytes failed")

		// Wait for the context to expire before attempting further writes.
		<-ctx.Done()

		_, err = io.Copy(writer, testreader.Looper([]byte{1}))
		assert.Equal(t, ErrTimeout, err, "Write should fail with timeout")

		ts.AssertRelayStats(relaytest.NewMockStats())
	})
}

// TestGracefulClose pings in both directions between the test server and a
// second server, and verifies that no relay stats are recorded for pings.
func TestGracefulClose(t *testing.T) {
	testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) {
		ch2 := ts.NewServer(nil)
		hp2 := ch2.PeerInfo().HostPort
		defer ch2.Close()

		ctx, cancel := NewContext(time.Second)
		defer cancel()

		assert.NoError(t, ts.Server().Ping(ctx, hp2), "Ping from ch1 -> ch2 failed")
		assert.NoError(t, ch2.Ping(ctx, ts.HostPort()), "Ping from ch2 -> ch1 failed")

		// No stats for pings.
		ts.AssertRelayStats(relaytest.NewMockStats())
	})
}

// TestNetDialTimeout verifies that a ping to an address that never answers
// fails with ErrTimeout, and not before the context timeout has elapsed.
func TestNetDialTimeout(t *testing.T) {
	// timeoutHostPort uses a blackholed address (RFC 6890) with a port
	// reserved for documentation. This address should always cause a timeout.
	const timeoutHostPort = "192.18.0.254:44444"
	timeoutPeriod := testutils.Timeout(50 * time.Millisecond)

	client := testutils.NewClient(t, nil)
	defer client.Close()

	started := time.Now()
	ctx, cancel := NewContext(timeoutPeriod)
	defer cancel()

	err := client.Ping(ctx, timeoutHostPort)
	if !assert.Error(t, err, "Ping to blackhole address should fail") {
		return
	}

	if strings.Contains(err.Error(), "network is unreachable") {
		t.Skipf("Skipping test, as network interface may not be available")
	}

	d := time.Since(started)
	assert.Equal(t, ErrTimeout, err, "Ping expected to fail with timeout")
	assert.True(t, d >= timeoutPeriod, "Timeout should take more than %v, took %v", timeoutPeriod, d)
}

// TestConnectTimeout verifies that the connect timeout, rather than the
// (longer) call timeout, bounds how long connection establishment may take.
func TestConnectTimeout(t *testing.T) {
	opts := testutils.NewOpts().DisableLogVerification()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// Set up a relay that will delay the initial init req.
		testComplete := make(chan struct{})

		relayFunc := func(outgoing bool, f *Frame) *Frame {
			select {
			case <-time.After(testutils.Timeout(200 * time.Millisecond)):
				return f
			case <-testComplete:
				// TODO: We should be able to forward the frame and have this test not fail.
				// Currently, it fails since the sequence of events is:
				// Server receives a TCP connection
				// Channel.Close() is called on the server
				// Server's TCP connection receives an init req
				// Since we don't currently track pending connections, the open TCP connection is not closed, and
				// we process the init req. This leaves an open connection at the end of the test.
				return nil
			}
		}
		relay, shutdown := testutils.FrameRelay(t, ts.HostPort(), relayFunc)
		defer shutdown()

		// Make a call with a long timeout, but short connect timeout.
		// We expect the call to fail almost immediately with ErrTimeout.
		ctx, cancel := NewContextBuilder(2 * time.Second).
			SetConnectTimeout(testutils.Timeout(100 * time.Millisecond)).
			Build()
		defer cancel()

		client := ts.NewClient(opts)
		err := client.Ping(ctx, relay)
		assert.Equal(t, ErrTimeout, err, "Ping should timeout due to timeout relay")

		// Note: we do not defer this, as we need to close(testComplete) before
		// we call shutdown since shutdown waits for the relay to close, which
		// is stuck waiting inside of our custom relay function.
		close(testComplete)
	})
}

// TestParallelConnectionAccepts verifies that a raw TCP connection that never
// completes the handshake does not block calls made over a new connection.
func TestParallelConnectionAccepts(t *testing.T) {
	opts := testutils.NewOpts().AddLogFilter("Failed during connection handshake", 1)
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		testutils.RegisterEcho(ts.Server(), nil)

		// Start a connection attempt that should timeout.
		conn, err := net.Dial("tcp", ts.HostPort())
		require.NoError(t, err, "Dial failed")
		// Only defer the close once the dial is known to have succeeded;
		// conn is nil on error and closing it would panic.
		defer conn.Close()

		// When we try to make a call using a new client, it will require a
		// new connection, and this verifies that the previous connection attempt
		// and handshake do not impact the call.
		client := ts.NewClient(nil)
		testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())
	})
}

// TestConnectionIDs records the frame IDs seen by a frame relay in each
// direction and verifies the IDs used for pings on new and reused connections.
func TestConnectionIDs(t *testing.T) {
	testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) {
		var inbound, outbound []uint32
		relayFunc := func(outgoing bool, f *Frame) *Frame {
			if outgoing {
				outbound = append(outbound, f.Header.ID)
			} else {
				inbound = append(inbound, f.Header.ID)
			}
			return f
		}
		relay, shutdown := testutils.FrameRelay(t, ts.HostPort(), relayFunc)
		defer shutdown()

		ctx, cancel := NewContext(time.Second)
		defer cancel()

		s2 := ts.NewServer(nil)
		require.NoError(t, s2.Ping(ctx, relay), "Ping failed")
		assert.Equal(t, []uint32{1, 2}, outbound, "Unexpected outbound IDs")
		assert.Equal(t, []uint32{1, 2}, inbound, "Unexpected inbound IDs")

		// We want to reuse the same connection for the rest of the test which
		// only makes sense when the relay is not used.
		if ts.HasRelay() {
			return
		}

		inbound = nil
		outbound = nil
		// We will reuse the inbound connection, but since the inbound connection
		// hasn't originated any outbound requests, we'll use id 1.
		require.NoError(t, ts.Server().Ping(ctx, s2.PeerInfo().HostPort), "Ping failed")
		assert.Equal(t, []uint32{1}, outbound, "Unexpected outbound IDs")
		assert.Equal(t, []uint32{1}, inbound, "Unexpected inbound IDs")
	})
}

// TestTosPriority verifies that the TOS priority configured through the
// channel options is reported for the outbound network connection.
func TestTosPriority(t *testing.T) {
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	opts := testutils.NewOpts().SetServiceName("s1").SetTosPriority(tos.Lowdelay)
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		ts.Register(raw.Wrap(newTestHandler(t)), "echo")

		outbound, err := ts.Server().BeginCall(ctx, ts.HostPort(), "s1", "echo", nil)
		require.NoError(t, err, "BeginCall failed")

		_, outboundNetConn := OutboundConnection(outbound)
		connTosPriority, err := isTosPriority(outboundNetConn, tos.Lowdelay)
		require.NoError(t, err, "Checking TOS priority failed")
		assert.True(t, connTosPriority, "Outbound connection should use the configured TOS priority")

		_, _, _, err = raw.WriteArgs(outbound, []byte("arg2"), []byte("arg3"))
		require.NoError(t, err, "Failed to write to outbound conn")
	})
}

// TestPeerStatusChangeClientReduction verifies the peer status events seen by
// a client with a single outbound connection when the server closes.
func TestPeerStatusChangeClientReduction(t *testing.T) {
	sopts := testutils.NewOpts().NoRelay()
	testutils.WithTestServer(t, sopts, func(t testing.TB, ts *testutils.TestServer) {
		server := ts.Server()
		testutils.RegisterEcho(server, nil)

		changes := make(chan int, 2)

		copts := testutils.NewOpts().SetOnPeerStatusChanged(func(p *Peer) {
			i, o := p.NumConnections()
			assert.Equal(t, 0, i, "no inbound connections to client")
			changes <- o
		})

		// Induce the creation of a connection from client to server.
		client := ts.NewClient(copts)
		require.NoError(t, testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil))
		assert.Equal(t, 1, <-changes, "event for first connection")

		// Re-use
		testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())

		// Induce the destruction of a connection from the server to the client.
		server.Close()
		assert.Equal(t, 0, <-changes, "event for second disconnection")

		client.Close()
		assert.Len(t, changes, 0, "unexpected peer status changes")
	})
}

// TestPeerStatusChangeClient verifies the peer status events seen by a client
// as it opens two outbound connections and the server then closes.
func TestPeerStatusChangeClient(t *testing.T) {
	sopts := testutils.NewOpts().NoRelay()
	testutils.WithTestServer(t, sopts, func(t testing.TB, ts *testutils.TestServer) {
		server := ts.Server()
		testutils.RegisterEcho(server, nil)

		changes := make(chan int, 2)

		copts := testutils.NewOpts().SetOnPeerStatusChanged(func(p *Peer) {
			i, o := p.NumConnections()
			assert.Equal(t, 0, i, "no inbound connections to client")
			changes <- o
		})

		// Induce the creation of a connection from client to server.
		client := ts.NewClient(copts)
		require.NoError(t, testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil))
		assert.Equal(t, 1, <-changes, "event for first connection")

		// Re-use
		testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())

		// Induce the creation of a second connection from client to server.
		pl := client.RootPeers()
		p := pl.GetOrAdd(ts.HostPort())
		ctx := context.Background()
		ctx, cancel := context.WithTimeout(ctx, testutils.Timeout(100*time.Millisecond))
		defer cancel()
		_, err := p.Connect(ctx)
		require.NoError(t, err)
		assert.Equal(t, 2, <-changes, "event for second connection")

		// Induce the destruction of a connection from the server to the client.
		server.Close()
		<-changes // May be 1 or 0 depending on timing.
		assert.Equal(t, 0, <-changes, "event for second disconnection")

		client.Close()
		assert.Len(t, changes, 0, "unexpected peer status changes")
	})
}

// TestPeerStatusChangeServer verifies that the server sees exactly one peer
// status event when a client connects and one when it disconnects.
func TestPeerStatusChangeServer(t *testing.T) {
	changes := make(chan int, 10)

	sopts := testutils.NewOpts().NoRelay().SetOnPeerStatusChanged(func(p *Peer) {
		i, o := p.NumConnections()
		assert.Equal(t, 0, o, "no outbound connections from server")
		changes <- i
	})
	testutils.WithTestServer(t, sopts, func(t testing.TB, ts *testutils.TestServer) {
		server := ts.Server()
		testutils.RegisterEcho(server, nil)

		copts := testutils.NewOpts()
		for i := 0; i < 5; i++ {
			client := ts.NewClient(copts)

			// Open
			testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())
			assert.Equal(t, 1, <-changes, "one event on new connection")

			// Re-use
			testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())
			assert.Len(t, changes, 0, "no new events on re-used connection")

			// Close
			client.Close()
			assert.Equal(t, 0, <-changes, "one event on lost connection")
		}
	})
	assert.Len(t, changes, 0, "unexpected peer status changes")
}

// TestContextCanceledOnTCPClose verifies that both the client call and the
// server handler finish promptly once the connection between them is closed.
func TestContextCanceledOnTCPClose(t *testing.T) {
	// 1. Context canceled warning is expected as part of this test
	// add log filter to ignore this error
	// 2. We use our own relay in this test, so disable the relay
	// that comes with the test server
	opts := testutils.NewOpts().NoRelay().AddLogFilter("simpleHandler OnError", 1)

	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		serverDoneC := make(chan struct{})
		callForwarded := make(chan struct{})

		ts.RegisterFunc("test", func(ctx context.Context, args *raw.Args) (*raw.Res, error) {
			defer close(serverDoneC)
			close(callForwarded)
			<-ctx.Done()
			assert.EqualError(t, ctx.Err(), "context canceled")
			return &raw.Res{}, nil
		})

		// Set up a relay that can be used to terminate conns
		// on both sides i.e. client and server
		relayFunc := func(outgoing bool, f *Frame) *Frame {
			return f
		}
		relayHostPort, shutdown := testutils.FrameRelay(t, ts.HostPort(), relayFunc)

		// Make a call with a long timeout. We shut down the relay
		// immediately after the server receives the call. Expected
		// behavior is for both client/server to be done with the call
		// immediately after the relay shuts down
		ctx, cancel := NewContext(20 * time.Second)
		defer cancel()

		clientCh := ts.NewClient(nil)
		// initiate the call in a background routine and
		// make it wait for the response
		clientDoneC := make(chan struct{})
		go func() {
			// The call result is intentionally ignored; only completion matters.
			raw.Call(ctx, clientCh, relayHostPort, ts.ServiceName(), "test", nil, nil)
			close(clientDoneC)
		}()

		// wait for server to receive the call
		select {
		case <-callForwarded:
		case <-time.After(2 * time.Second):
			assert.Fail(t, "timed out waiting for call to be forwarded")
		}

		// now shutdown the relay to close conns
		// on both sides
		shutdown()

		// wait for both the client & server to be done
		select {
		case <-serverDoneC:
		case <-time.After(2 * time.Second):
			assert.Fail(t, "timed out waiting for server handler to exit")
		}

		select {
		case <-clientDoneC:
		case <-time.After(2 * time.Second):
			assert.Fail(t, "timed out waiting for client to exit")
		}

		clientCh.Close()
	})
}

// getConnection returns the introspection result for the unique inbound or
// outbound connection. An assert will be raised if there is more than one
// connection of the given type.
func getConnection(t testing.TB, ch *Channel, direction int) *ConnectionRuntimeState { state := ch.IntrospectState(nil) for _, peer := range state.RootPeers { var connections []ConnectionRuntimeState if direction == inbound { connections = peer.InboundConnections } else { connections = peer.OutboundConnections } assert.True(t, len(connections) <= 1, "Too many connections found: %+v", connections) if len(connections) == 1 { return &connections[0] } } assert.FailNow(t, "No connections found") return nil } func TestLastActivityTime(t *testing.T) { initialTime := time.Date(2017, 11, 27, 21, 0, 0, 0, time.UTC) clock := testutils.NewStubClock(initialTime) opts := testutils.NewOpts().SetTimeNow(clock.Now) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { client := ts.NewClient(opts) server := ts.Server() // Channels for synchronization. callReceived := make(chan struct{}) blockResponse := make(chan struct{}) responseReceived := make(chan struct{}) // Helper function that checks the last activity time on client, server and relay. validateLastActivityTime := func(expectedReq time.Time, expectedResp time.Time) { clientConn := getConnection(t, client, outbound) serverConn := getConnection(t, server, inbound) reqTime := expectedReq.UnixNano() respTime := expectedResp.UnixNano() assert.Equal(t, reqTime, clientConn.LastActivityWrite) assert.Equal(t, reqTime, serverConn.LastActivityRead) assert.Equal(t, respTime, clientConn.LastActivityRead) assert.Equal(t, respTime, serverConn.LastActivityWrite) // Relays should act like both clients and servers. 
if ts.HasRelay() { relayInbound := getConnection(t, ts.Relay(), inbound) relayOutbound := getConnection(t, ts.Relay(), outbound) assert.Equal(t, reqTime, relayInbound.LastActivityRead) assert.Equal(t, reqTime, relayOutbound.LastActivityWrite) assert.Equal(t, respTime, relayInbound.LastActivityWrite) assert.Equal(t, respTime, relayOutbound.LastActivityRead) } } // The 'echo' handler emulates a process that takes 1 second to complete. testutils.RegisterEcho(server, func() { callReceived <- struct{}{} <-blockResponse // Increment the time and return a response. clock.Elapse(1 * time.Second) }) initTime := clock.Now() // Run the test twice, because the first call will also establish a connection. for i := 0; i < 2; i++ { beforeCallSent := clock.Now() go func() { require.NoError(t, testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil)) responseReceived <- struct{}{} }() // Verify that the last activity time was updated before a response is received. <-callReceived validateLastActivityTime(beforeCallSent, initTime) // Let the server respond. blockResponse <- struct{}{} // After a response was received, time of the response should be +1s, // without a change to the requet time. Validate again that the last // activity time was updated. <-responseReceived validateLastActivityTime(beforeCallSent, beforeCallSent.Add(1*time.Second)) // Set the initTime as the time of the last response. initTime = beforeCallSent.Add(1 * time.Second) // Elapse the clock for our next iteration. 
clock.Elapse(1 * time.Minute) } close(responseReceived) close(blockResponse) close(callReceived) }) } func TestLastActivityTimePings(t *testing.T) { initialTime := time.Date(2017, 11, 27, 21, 0, 0, 0, time.UTC) clock := testutils.NewStubClock(initialTime) opts := testutils.NewOpts().SetTimeNow(clock.Now) ctx, cancel := NewContext(testutils.Timeout(100 * time.Millisecond)) defer cancel() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { client := ts.NewClient(opts) // Send an 'echo' to establish the connection. testutils.RegisterEcho(ts.Server(), nil) require.NoError(t, testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil)) timeAtStart := clock.Now().UnixNano() for i := 0; i < 2; i++ { require.NoError(t, client.Ping(ctx, ts.HostPort())) // Verify last activity time. clientConn := getConnection(t, client, outbound) assert.Equal(t, timeAtStart, clientConn.LastActivityRead) assert.Equal(t, timeAtStart, clientConn.LastActivityWrite) // Relays do not pass pings on to the server. if ts.HasRelay() { relayInbound := getConnection(t, ts.Relay(), inbound) assert.Equal(t, timeAtStart, relayInbound.LastActivityRead) assert.Equal(t, timeAtStart, relayInbound.LastActivityWrite) } serverConn := getConnection(t, ts.Server(), inbound) assert.Equal(t, timeAtStart, serverConn.LastActivityRead) assert.Equal(t, timeAtStart, serverConn.LastActivityWrite) clock.Elapse(1 * time.Second) } }) } func TestSendBufferSize(t *testing.T) { opts := testutils.NewOpts().SetSendBufferSize(512).SetSendBufferSizeOverrides([]SendBufferSizeOverride{ {"abc", 1024}, {"abcd", 2048}, // This should never match, since we match the list in order. 
{"xyz", 3072}, }) tests := []struct { processName string expectSendChCapacity int }{ { processName: "abc", expectSendChCapacity: 1024, }, { processName: "abcd", expectSendChCapacity: 1024, }, { processName: "bcd", expectSendChCapacity: DefaultConnectionBufferSize, }, { processName: "dabc", expectSendChCapacity: DefaultConnectionBufferSize, }, { processName: "dabcd", expectSendChCapacity: DefaultConnectionBufferSize, }, { processName: "abcde", expectSendChCapacity: 1024, }, { processName: "xyzabc", expectSendChCapacity: 3072, }, } for _, tt := range tests { t.Run(tt.processName, func(t *testing.T) { testutils.WithTestServer(t, opts, func(tb testing.TB, ts *testutils.TestServer) { client := ts.NewClient(opts.SetProcessName(tt.processName)) // Send an 'echo' to establish the connection. testutils.RegisterEcho(ts.Server(), nil) require.NoError(t, testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil)) // WithTestSever will test with and without relay. if ts.HasRelay() { assert.Equal(t, tt.expectSendChCapacity, getConnection(t, ts.Relay(), inbound).SendChCapacity) } else { assert.Equal(t, tt.expectSendChCapacity, getConnection(t, ts.Server(), inbound).SendChCapacity) } }) }) } } func TestInvalidTransportHeaders(t *testing.T) { long100 := strings.Repeat("0123456789", 10) long300 := strings.Repeat("0123456789", 30) tests := []struct { msg string ctxFn func(*ContextBuilder) svcOverride string wantErr string }{ { msg: "valid long fields", ctxFn: func(cb *ContextBuilder) { cb.SetRoutingKey(long100) cb.SetShardKey(long100) }, }, { msg: "long routing key", ctxFn: func(cb *ContextBuilder) { cb.SetRoutingKey(long300) }, wantErr: "too long", }, { msg: "long shard key", ctxFn: func(cb *ContextBuilder) { cb.SetShardKey(long300) }, wantErr: "too long", }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { testutils.RegisterEcho(ts.Server(), nil) client := ts.NewClient(nil) cb := 
NewContextBuilder(time.Second) tt.ctxFn(cb) ctx, cancel := cb.Build() defer cancel() svc := ts.ServiceName() if tt.svcOverride != "" { svc = tt.svcOverride } _, _, _, err := raw.Call(ctx, client, ts.HostPort(), svc, "echo", nil, nil) if tt.wantErr == "" { require.NoError(t, err, "unexpected error") return } require.Error(t, err) assert.Contains(t, err.Error(), tt.wantErr, "unexpected error") }) }) } } func TestCustomDialer(t *testing.T) { sopts := testutils.NewOpts() testutils.WithTestServer(t, sopts, func(t testing.TB, ts *testutils.TestServer) { server := ts.Server() testutils.RegisterEcho(server, nil) customDialerCalledCount := 0 copts := testutils.NewOpts().SetDialer(func(ctx context.Context, network, hostPort string) (net.Conn, error) { customDialerCalledCount++ d := net.Dialer{} return d.DialContext(ctx, network, hostPort) }) // Induce the creation of a connection from client to server. client := ts.NewClient(copts) testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection") // Re-use testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection") }) } func TestInboundConnContext(t *testing.T) { opts := testutils.NewOpts().NoRelay().SetConnContext(func(ctx context.Context, conn net.Conn) context.Context { return context.WithValue(ctx, "foo", "bar") }) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { alice := ts.Server() testutils.RegisterFunc(alice, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { // Verify that the context passed into the handler inherits from the base context // set by ConnContext assert.Equal(t, "bar", ctx.Value("foo"), "Value unexpectedly different from base context") return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, nil }) copts := testutils.NewOpts() bob := ts.NewClient(copts) 
testutils.AssertEcho(t, bob, ts.HostPort(), ts.ServiceName()) }) } func TestOutboundConnContext(t *testing.T) { opts := testutils.NewOpts().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { alice := ts.Server() testutils.RegisterFunc(alice, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { assert.Equal(t, "bar", ctx.Value("foo"), "Base context key unexpectedly absent") return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, nil }) bobOpts := testutils.NewOpts().SetServiceName("bob") bob := ts.NewServer(bobOpts) testutils.RegisterEcho(bob, nil) baseCtx := context.WithValue(context.Background(), "foo", "bar") ctx, cancel := NewContextBuilder(time.Second).SetConnectBaseContext(baseCtx).Build() defer cancel() err := alice.Ping(ctx, bob.PeerInfo().HostPort) require.NoError(t, err) testutils.AssertEcho(t, bob, ts.HostPort(), ts.ServiceName()) }) } func TestWithTLSNoRelay(t *testing.T) { // NOTE: "Connection does not implement SyscallConn." logs are filtered as tls.Conn doesn't implement syscall.Conn. sopts := testutils.NewOpts().SetServeTLS(true).NoRelay(). AddLogFilter("Connection does not implement SyscallConn.", 1) testutils.WithTestServer(t, sopts, func(t testing.TB, ts *testutils.TestServer) { server := ts.Server() testutils.RegisterEcho(server, nil) customDialerCalledCount := 0 copts := testutils.NewOpts().SetDialer(func(ctx context.Context, network, hostPort string) (net.Conn, error) { customDialerCalledCount++ d := tls.Dialer{ Config: &tls.Config{InsecureSkipVerify: true}, } return d.DialContext(ctx, network, hostPort) }).AddLogFilter("Connection does not implement SyscallConn.", 1) // Induce the creation of a connection from client to server. 
client := ts.NewClient(copts) testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection") // Re-use testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection") }) } func TestWithTLSRelayOnly(t *testing.T) { // NOTE: "Connection does not implement SyscallConn." logs are filtered as tls.Conn doesn't implement syscall.Conn. // SetDialer with tls.Dial as relay uses dialer from server opts to make outbound connections. sopts := testutils.NewOpts().SetServeTLS(true).SetRelayOnly().SetDialer(func(ctx context.Context, network, hostPort string) (net.Conn, error) { d := tls.Dialer{ Config: &tls.Config{InsecureSkipVerify: true}, } return d.DialContext(ctx, network, hostPort) }).AddLogFilter("Connection does not implement SyscallConn.", 2) // 1 + 1 for server & relay testutils.WithTestServer(t, sopts, func(t testing.TB, ts *testutils.TestServer) { server := ts.Server() testutils.RegisterEcho(server, nil) customDialerCalledCount := 0 copts := testutils.NewOpts().SetDialer(func(ctx context.Context, network, hostPort string) (net.Conn, error) { customDialerCalledCount++ d := tls.Dialer{ Config: &tls.Config{InsecureSkipVerify: true}, } return d.DialContext(ctx, network, hostPort) }).AddLogFilter("Connection does not implement SyscallConn.", 1) // Induce the creation of a connection from client to server. 
client := ts.NewClient(copts) testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection") // Re-use testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) assert.Equal(t, 1, customDialerCalledCount, "custom dialer used for establishing connection") }) } ================================================ FILE: connectionstate_string.go ================================================ // Code generated by "stringer -type=connectionState"; DO NOT EDIT package tchannel import "fmt" const _connectionState_name = "connectionActiveconnectionStartCloseconnectionInboundClosedconnectionClosed" var _connectionState_index = [...]uint8{0, 16, 36, 59, 75} func (i connectionState) String() string { i -= 1 if i < 0 || i >= connectionState(len(_connectionState_index)-1) { return fmt.Sprintf("connectionState(%d)", i+1) } return _connectionState_name[_connectionState_index[i]:_connectionState_index[i+1]] } ================================================ FILE: context.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "time" "golang.org/x/net/context" ) const defaultTimeout = time.Second type contextKey int const ( contextKeyTChannel contextKey = iota contextKeyHeaders ) type tchannelCtxParams struct { tracingDisabled bool hideListeningOnOutbound bool call IncomingCall options *CallOptions retryOptions *RetryOptions connectTimeout time.Duration connectBaseContext context.Context } // IncomingCall exposes properties for incoming calls through the context. type IncomingCall interface { // CallerName returns the caller name from the CallerName transport header. CallerName() string // ShardKey returns the shard key from the ShardKey transport header. ShardKey() string // RoutingKey returns the routing key (referring to a traffic group) from // RoutingKey transport header. RoutingKey() string // RoutingDelegate returns the routing delegate from RoutingDelegate // transport header. RoutingDelegate() string // LocalPeer returns the local peer information. LocalPeer() LocalPeerInfo // RemotePeer returns the caller's peer information. // If the caller is an ephemeral peer, then the HostPort cannot be used to make new // connections to the caller. RemotePeer() PeerInfo // CallOptions returns the call options set for the incoming call. It can be // useful for forwarding requests. CallOptions() *CallOptions } func getTChannelParams(ctx context.Context) *tchannelCtxParams { if params, ok := ctx.Value(contextKeyTChannel).(*tchannelCtxParams); ok { return params } return nil } // NewContext returns a new root context used to make TChannel requests. 
func NewContext(timeout time.Duration) (context.Context, context.CancelFunc) { return NewContextBuilder(timeout).Build() } // WrapContextForTest returns a copy of the given Context that is associated with the call. // This should be used in units test only. // NOTE: This method is deprecated. Callers should use NewContextBuilder().SetIncomingCallForTest. func WrapContextForTest(ctx context.Context, call IncomingCall) context.Context { getTChannelParams(ctx).call = call return ctx } // newIncomingContext creates a new context for an incoming call with the given span. func newIncomingContext(ctx context.Context, call IncomingCall, timeout time.Duration) (context.Context, context.CancelFunc) { return NewContextBuilder(timeout). SetParentContext(ctx). setIncomingCall(call). Build() } // CurrentCall returns the current incoming call, or nil if this is not an incoming call context. func CurrentCall(ctx context.Context) IncomingCall { if params := getTChannelParams(ctx); params != nil { return params.call } return nil } func currentCallOptions(ctx context.Context) *CallOptions { if params := getTChannelParams(ctx); params != nil { return params.options } return nil } func isTracingDisabled(ctx context.Context) bool { if params := getTChannelParams(ctx); params != nil { return params.tracingDisabled } return false } ================================================ FILE: context_builder.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "time" "golang.org/x/net/context" ) // ContextBuilder stores all TChannel-specific parameters that will // be stored inside of a context. type ContextBuilder struct { // TracingDisabled disables trace reporting for calls using this context. TracingDisabled bool // hideListeningOnOutbound disables sending the listening server's host:port // when creating new outgoing connections. hideListeningOnOutbound bool // replaceParentHeaders is set to true when SetHeaders() method is called. // It forces headers from ParentContext to be ignored. When false, parent // headers will be merged with headers accumulated by the builder. replaceParentHeaders bool // If Timeout is zero, Build will default to defaultTimeout. Timeout time.Duration // Headers are application headers that json/thrift will encode into arg2. Headers map[string]string // CallOptions are TChannel call options for the specific call. 
CallOptions *CallOptions // RetryOptions are the retry options for this call. RetryOptions *RetryOptions // ConnectTimeout is the timeout for creating a TChannel connection. ConnectTimeout time.Duration // ConnectBaseContext is the base context for all connections ConnectBaseContext context.Context // ParentContext to build the new context from. If empty, context.Background() is used. // The new (child) context inherits a number of properties from the parent context: // - context fields, accessible via `ctx.Value(key)` // - headers if parent is a ContextWithHeaders, unless replaced via SetHeaders() ParentContext context.Context // Hidden fields: we do not want users outside of tchannel to set these. incomingCall IncomingCall } // NewContextBuilder returns a builder that can be used to create a Context. func NewContextBuilder(timeout time.Duration) *ContextBuilder { return &ContextBuilder{ Timeout: timeout, } } // SetTimeout sets the timeout for the Context. func (cb *ContextBuilder) SetTimeout(timeout time.Duration) *ContextBuilder { cb.Timeout = timeout return cb } // AddHeader adds a single application header to the Context. func (cb *ContextBuilder) AddHeader(key, value string) *ContextBuilder { if cb.Headers == nil { cb.Headers = map[string]string{key: value} } else { cb.Headers[key] = value } return cb } // SetHeaders sets the application headers for this Context. // If there is a ParentContext, its headers will be ignored after the call to this method. func (cb *ContextBuilder) SetHeaders(headers map[string]string) *ContextBuilder { cb.Headers = headers cb.replaceParentHeaders = true return cb } // SetShardKey sets the ShardKey call option ("sk" transport header). func (cb *ContextBuilder) SetShardKey(sk string) *ContextBuilder { if cb.CallOptions == nil { cb.CallOptions = new(CallOptions) } cb.CallOptions.ShardKey = sk return cb } // SetFormat sets the Format call option ("as" transport header). 
func (cb *ContextBuilder) SetFormat(f Format) *ContextBuilder { if cb.CallOptions == nil { cb.CallOptions = new(CallOptions) } cb.CallOptions.Format = f return cb } // SetRoutingKey sets the RoutingKey call options ("rk" transport header). func (cb *ContextBuilder) SetRoutingKey(rk string) *ContextBuilder { if cb.CallOptions == nil { cb.CallOptions = new(CallOptions) } cb.CallOptions.RoutingKey = rk return cb } // SetRoutingDelegate sets the RoutingDelegate call options ("rd" transport header). func (cb *ContextBuilder) SetRoutingDelegate(rd string) *ContextBuilder { if cb.CallOptions == nil { cb.CallOptions = new(CallOptions) } cb.CallOptions.RoutingDelegate = rd return cb } // SetConnectTimeout sets the ConnectionTimeout for this context. // The context timeout applies to the whole call, while the connect // timeout only applies to creating a new connection. func (cb *ContextBuilder) SetConnectTimeout(d time.Duration) *ContextBuilder { cb.ConnectTimeout = d return cb } // SetConnectBaseContext sets the base context for any outbound connection created func (cb *ContextBuilder) SetConnectBaseContext(ctx context.Context) *ContextBuilder { cb.ConnectBaseContext = ctx return cb } // HideListeningOnOutbound hides the host:port when creating new outbound // connections. func (cb *ContextBuilder) HideListeningOnOutbound() *ContextBuilder { cb.hideListeningOnOutbound = true return cb } // DisableTracing disables tracing. func (cb *ContextBuilder) DisableTracing() *ContextBuilder { cb.TracingDisabled = true return cb } // SetIncomingCallForTest sets an IncomingCall in the context. // This should only be used in unit tests. func (cb *ContextBuilder) SetIncomingCallForTest(call IncomingCall) *ContextBuilder { return cb.setIncomingCall(call) } // SetRetryOptions sets RetryOptions in the context. 
func (cb *ContextBuilder) SetRetryOptions(retryOptions *RetryOptions) *ContextBuilder { cb.RetryOptions = retryOptions return cb } // SetTimeoutPerAttempt sets TimeoutPerAttempt in RetryOptions. func (cb *ContextBuilder) SetTimeoutPerAttempt(timeoutPerAttempt time.Duration) *ContextBuilder { if cb.RetryOptions == nil { cb.RetryOptions = &RetryOptions{} } cb.RetryOptions.TimeoutPerAttempt = timeoutPerAttempt return cb } // SetParentContext sets the parent for the Context. func (cb *ContextBuilder) SetParentContext(ctx context.Context) *ContextBuilder { cb.ParentContext = ctx return cb } func (cb *ContextBuilder) setIncomingCall(call IncomingCall) *ContextBuilder { cb.incomingCall = call return cb } func (cb *ContextBuilder) getHeaders() map[string]string { if cb.ParentContext == nil || cb.replaceParentHeaders { return cb.Headers } parent, ok := cb.ParentContext.Value(contextKeyHeaders).(*headersContainer) if !ok || len(parent.reqHeaders) == 0 { return cb.Headers } mergedHeaders := make(map[string]string, len(cb.Headers)+len(parent.reqHeaders)) for k, v := range parent.reqHeaders { mergedHeaders[k] = v } for k, v := range cb.Headers { mergedHeaders[k] = v } return mergedHeaders } // Build returns a ContextWithHeaders that can be used to make calls. func (cb *ContextBuilder) Build() (ContextWithHeaders, context.CancelFunc) { params := &tchannelCtxParams{ options: cb.CallOptions, call: cb.incomingCall, retryOptions: cb.RetryOptions, connectTimeout: cb.ConnectTimeout, hideListeningOnOutbound: cb.hideListeningOnOutbound, tracingDisabled: cb.TracingDisabled, connectBaseContext: cb.ConnectBaseContext, } parent := cb.ParentContext if parent == nil { parent = context.Background() } else if headerCtx, ok := parent.(headerCtx); ok { // Unwrap any headerCtx, since we'll be rewrapping anyway. 
parent = headerCtx.Context } var ( ctx context.Context cancel context.CancelFunc ) // All contexts created must have a timeout, but if the parent // already has a timeout, and the user has not specified one, then we // can use context.WithCancel _, parentHasDeadline := parent.Deadline() if cb.Timeout == 0 && parentHasDeadline { ctx, cancel = context.WithCancel(parent) } else { ctx, cancel = context.WithTimeout(parent, cb.Timeout) } ctx = context.WithValue(ctx, contextKeyTChannel, params) return WrapWithHeaders(ctx, cb.getHeaders()), cancel } ================================================ FILE: context_header.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import "golang.org/x/net/context" // ContextWithHeaders is a Context which contains request and response headers. 
type ContextWithHeaders interface { context.Context // Headers returns the call request headers. Headers() map[string]string // ResponseHeaders returns the call response headers. ResponseHeaders() map[string]string // SetResponseHeaders sets the given response headers on the context. SetResponseHeaders(map[string]string) // Child creates a child context which stores headers separately from // the parent context. Child() ContextWithHeaders } type headerCtx struct { context.Context } // headersContainer stores the headers, and is itself stored in the context under `contextKeyHeaders` type headersContainer struct { reqHeaders map[string]string respHeaders map[string]string } func (c headerCtx) headers() *headersContainer { if h, ok := c.Value(contextKeyHeaders).(*headersContainer); ok { return h } return nil } // Headers gets application headers out of the context. func (c headerCtx) Headers() map[string]string { if h := c.headers(); h != nil { return h.reqHeaders } return nil } // ResponseHeaders returns the response headers. func (c headerCtx) ResponseHeaders() map[string]string { if h := c.headers(); h != nil { return h.respHeaders } return nil } // SetResponseHeaders sets the response headers. func (c headerCtx) SetResponseHeaders(headers map[string]string) { if h := c.headers(); h != nil { h.respHeaders = headers return } panic("SetResponseHeaders called on ContextWithHeaders not created via WrapWithHeaders") } // Child creates a child context with a separate container for headers. func (c headerCtx) Child() ContextWithHeaders { var headersCopy headersContainer if h := c.headers(); h != nil { headersCopy = *h } return Wrap(context.WithValue(c.Context, contextKeyHeaders, &headersCopy)) } // Wrap wraps an existing context.Context into a ContextWithHeaders. // If the underlying context has headers, they are preserved. 
func Wrap(ctx context.Context) ContextWithHeaders { hctx := headerCtx{Context: ctx} if h := hctx.headers(); h != nil { return hctx } // If there is no header container, we should create an empty one. return WrapWithHeaders(ctx, nil) } // WrapWithHeaders returns a Context that can be used to make a call with request headers. // If the parent `ctx` is already an instance of ContextWithHeaders, its existing headers // will be ignored. In order to merge new headers with parent headers, use ContextBuilder. func WrapWithHeaders(ctx context.Context, headers map[string]string) ContextWithHeaders { h := &headersContainer{ reqHeaders: headers, } newCtx := context.WithValue(ctx, contextKeyHeaders, h) return headerCtx{Context: newCtx} } // WithoutHeaders hides any TChannel headers from the given context. func WithoutHeaders(ctx context.Context) context.Context { return context.WithValue(context.WithValue(ctx, contextKeyTChannel, nil), contextKeyHeaders, nil) } ================================================ FILE: context_internal_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "testing" "time" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/mocktracer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) func TestNewContextBuilderDisableTracing(t *testing.T) { ctx, cancel := NewContextBuilder(time.Second). DisableTracing().Build() defer cancel() assert.True(t, isTracingDisabled(ctx), "Tracing should be disabled") } func TestCurrentSpan(t *testing.T) { ctx := context.Background() span := CurrentSpan(ctx) require.NotNil(t, span, "CurrentSpan() should always return something") tracer := mocktracer.New() sp := tracer.StartSpan("test") ctx = opentracing.ContextWithSpan(ctx, sp) span = CurrentSpan(ctx) require.NotNil(t, span, "CurrentSpan() should always return something") assert.EqualValues(t, 0, span.TraceID(), "mock tracer is not Zipkin-compatible") tracer.RegisterInjector(zipkinSpanFormat, new(zipkinInjector)) span = CurrentSpan(ctx) require.NotNil(t, span, "CurrentSpan() should always return something") assert.NotEqual(t, uint64(0), span.TraceID(), "mock tracer is now Zipkin-compatible") } func TestContextWithoutHeadersKeyHeaders(t *testing.T) { ctx := WrapWithHeaders(context.Background(), map[string]string{"k1": "v1"}) assert.Equal(t, map[string]string{"k1": "v1"}, ctx.Headers()) ctx2 := WithoutHeaders(ctx) assert.Nil(t, ctx2.Value(contextKeyHeaders)) _, ok := ctx2.(ContextWithHeaders) assert.False(t, ok) } func TestContextWithoutHeadersKeyTChannel(t *testing.T) { ctx, _ := NewContextBuilder(time.Second).SetShardKey("s1").Build() ctx2 := WithoutHeaders(ctx) assert.Nil(t, ctx2.Value(contextKeyTChannel)) _, ok := ctx2.(ContextWithHeaders) assert.False(t, ok) } 
================================================ FILE: context_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/testutils/goroutines" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) var cn = "hello" func TestWrapContextForTest(t *testing.T) { call := testutils.NewIncomingCall(cn) ctx, cancel := NewContext(time.Second) defer cancel() actual := WrapContextForTest(ctx, call) assert.Equal(t, call, CurrentCall(actual), "Incorrect call object returned.") } func TestNewContextTimeoutZero(t *testing.T) { ctx, cancel := NewContextBuilder(0).Build() defer cancel() deadline, ok := ctx.Deadline() assert.True(t, ok, "Context missing deadline") assert.True(t, deadline.Sub(time.Now()) <= 0, "Deadline should be Now or earlier") } func TestRoutingDelegatePropagates(t *testing.T) { WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { peerInfo := ch.PeerInfo() testutils.RegisterFunc(ch, "test", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ Arg3: []byte(CurrentCall(ctx).RoutingDelegate()), }, nil }) ctx, cancel := NewContextBuilder(time.Second).Build() defer cancel() _, arg3, _, err := raw.Call(ctx, ch, peerInfo.HostPort, peerInfo.ServiceName, "test", nil, nil) assert.NoError(t, err, "Call failed") assert.Equal(t, "", string(arg3), "Expected no routing delegate header") ctx, cancel = NewContextBuilder(time.Second).SetRoutingDelegate("xpr").Build() defer cancel() _, arg3, _, err = raw.Call(ctx, ch, peerInfo.HostPort, peerInfo.ServiceName, "test", nil, nil) assert.NoError(t, err, "Call failed") assert.Equal(t, "xpr", string(arg3), "Expected routing delegate header to be set") }) } func TestRoutingKeyPropagates(t *testing.T) { WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { peerInfo := ch.PeerInfo() testutils.RegisterFunc(ch, "test", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ Arg3: 
[]byte(CurrentCall(ctx).RoutingKey()), }, nil }) ctx, cancel := NewContextBuilder(time.Second).Build() defer cancel() _, arg3, _, err := raw.Call(ctx, ch, peerInfo.HostPort, peerInfo.ServiceName, "test", nil, nil) assert.NoError(t, err, "Call failed") assert.Equal(t, "", string(arg3), "Expected no routing key header") ctx, cancel = NewContextBuilder(time.Second).SetRoutingKey("canary").Build() defer cancel() _, arg3, _, err = raw.Call(ctx, ch, peerInfo.HostPort, peerInfo.ServiceName, "test", nil, nil) assert.NoError(t, err, "Call failed") assert.Equal(t, "canary", string(arg3), "Expected routing key header to be set") }) } func TestShardKeyPropagates(t *testing.T) { WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { peerInfo := ch.PeerInfo() testutils.RegisterFunc(ch, "test", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ Arg3: []byte(CurrentCall(ctx).ShardKey()), }, nil }) ctx, cancel := NewContextBuilder(time.Second).Build() defer cancel() _, arg3, _, err := raw.Call(ctx, ch, peerInfo.HostPort, peerInfo.ServiceName, "test", nil, nil) assert.NoError(t, err, "Call failed") assert.Equal(t, arg3, []byte("")) ctx, cancel = NewContextBuilder(time.Second). SetShardKey("shard").Build() defer cancel() _, arg3, _, err = raw.Call(ctx, ch, peerInfo.HostPort, peerInfo.ServiceName, "test", nil, nil) assert.NoError(t, err, "Call failed") assert.Equal(t, string(arg3), "shard") }) } func TestCurrentCallWithNilResult(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() call := CurrentCall(ctx) assert.Nil(t, call, "Should return nil.") } func getParentContext(t *testing.T) ContextWithHeaders { ctx := context.WithValue(context.Background(), "some key", "some value") assert.Equal(t, "some value", ctx.Value("some key")) ctx1, _ := NewContextBuilder(time.Second). SetParentContext(ctx). AddHeader("header key", "header value"). 
Build() assert.Equal(t, "some value", ctx1.Value("some key")) return ctx1 } func TestContextBuilderParentContextNoHeaders(t *testing.T) { ctx := getParentContext(t) assert.Equal(t, map[string]string{"header key": "header value"}, ctx.Headers()) assert.EqualValues(t, "some value", ctx.Value("some key"), "inherited from parent ctx") } func TestContextBuilderParentContextMergeHeaders(t *testing.T) { ctx := getParentContext(t) ctx.Headers()["fixed header"] = "fixed value" // append header to parent ctx2, _ := NewContextBuilder(time.Second). SetParentContext(ctx). AddHeader("header key 2", "header value 2"). Build() assert.Equal(t, map[string]string{ "header key": "header value", // inherited "fixed header": "fixed value", // inherited "header key 2": "header value 2", // appended }, ctx2.Headers()) // override parent header ctx3, _ := NewContextBuilder(time.Second). SetParentContext(ctx). AddHeader("header key", "header value 2"). // override Build() assert.Equal(t, map[string]string{ "header key": "header value 2", // overwritten "fixed header": "fixed value", // inherited }, ctx3.Headers()) goroutines.VerifyNoLeaks(t, nil) } func TestContextBuilderParentContextReplaceHeaders(t *testing.T) { ctx := getParentContext(t) ctx.Headers()["fixed header"] = "fixed value" assert.Equal(t, map[string]string{ "header key": "header value", "fixed header": "fixed value", }, ctx.Headers()) // replace headers with a new map ctx2, _ := NewContextBuilder(time.Second). SetParentContext(ctx). SetHeaders(map[string]string{"header key": "header value 2"}). Build() assert.Equal(t, map[string]string{"header key": "header value 2"}, ctx2.Headers()) goroutines.VerifyNoLeaks(t, nil) } func TestContextWrapWithHeaders(t *testing.T) { headers1 := map[string]string{ "k1": "v1", } ctx, _ := NewContextBuilder(time.Second). SetHeaders(headers1). 
Build() assert.Equal(t, headers1, ctx.Headers(), "Headers mismatch after Build") headers2 := map[string]string{ "k1": "v1", } ctx2 := WrapWithHeaders(ctx, headers2) assert.Equal(t, headers2, ctx2.Headers(), "Headers mismatch after WrapWithHeaders") } func TestContextWithHeadersAsContext(t *testing.T) { var ctx context.Context = getParentContext(t) assert.EqualValues(t, "some value", ctx.Value("some key"), "inherited from parent ctx") } func TestContextBuilderParentContextSpan(t *testing.T) { ctx := getParentContext(t) assert.Equal(t, "some value", ctx.Value("some key")) ctx2, _ := NewContextBuilder(time.Second). SetParentContext(ctx). Build() assert.Equal(t, "some value", ctx2.Value("some key"), "key/value propagated from parent ctx") goroutines.VerifyNoLeaks(t, nil) } func TestContextWrapChild(t *testing.T) { tests := []struct { msg string ctxFn func() ContextWithHeaders wantHeaders map[string]string wantValue interface{} }{ { msg: "Basic context", ctxFn: func() ContextWithHeaders { ctxNoHeaders, _ := NewContextBuilder(time.Second).Build() return ctxNoHeaders }, wantHeaders: nil, wantValue: nil, }, { msg: "Wrap basic context with value", ctxFn: func() ContextWithHeaders { ctxNoHeaders, _ := NewContextBuilder(time.Second).Build() return Wrap(context.WithValue(ctxNoHeaders, "1", "2")) }, wantHeaders: nil, wantValue: "2", }, { msg: "Wrap context with headers and value", ctxFn: func() ContextWithHeaders { ctxWithHeaders, _ := NewContextBuilder(time.Second).AddHeader("h1", "v1").Build() return Wrap(context.WithValue(ctxWithHeaders, "1", "2")) }, wantHeaders: map[string]string{"h1": "v1"}, wantValue: "2", }, } for _, tt := range tests { for _, child := range []bool{false, true} { origCtx := tt.ctxFn() ctx := origCtx if child { ctx = origCtx.Child() } assert.Equal(t, tt.wantValue, ctx.Value("1"), "%v: Unexpected value", tt.msg) assert.Equal(t, tt.wantHeaders, ctx.Headers(), "%v: Unexpected headers", tt.msg) respHeaders := map[string]string{"r": "v"} 
ctx.SetResponseHeaders(respHeaders) assert.Equal(t, respHeaders, ctx.ResponseHeaders(), "%v: Unexpected response headers", tt.msg) if child { // If we're working with a child context, changes to response headers // should not affect the original context. assert.Nil(t, origCtx.ResponseHeaders(), "%v: Child modified original context's headers", tt.msg) } } } } func TestContextInheritParentTimeout(t *testing.T) { deadlineAfter := time.Now().Add(time.Hour) pctx, cancel := context.WithTimeout(context.Background(), time.Hour) defer cancel() ctxBuilder := &ContextBuilder{ ParentContext: pctx, } ctx, cancel := ctxBuilder.Build() defer cancel() // Ensure deadline is in the future deadline, ok := ctx.Deadline() require.True(t, ok, "Missing deadline") assert.False(t, deadline.Before(deadlineAfter), "Expected deadline to be after %v, got %v", deadlineAfter, deadline) } ================================================ FILE: deps_test.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // Our glide.yaml lists a set of directories in excludeDirs to avoid packages // only used in testing from pulling in dependencies that should not affect // package resolution for clients. // However, we really want these directories to be part of test imports. Since // glide does not provide a "testDirs" option, we add dependencies required // for tests in this _test.go file. package tchannel_test import ( "testing" jcg "github.com/uber/jaeger-client-go" ) func TestJaegerDeps(t *testing.T) { m := jcg.Metrics{} _ = m.SamplerUpdateFailure } ================================================ FILE: dial_16.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. //go:build !go1.7 // +build !go1.7 package tchannel import ( "net" "golang.org/x/net/context" ) func dialContext(ctx context.Context, hostPort string) (net.Conn, error) { timeout := getTimeout(ctx) return net.DialTimeout("tcp", hostPort, timeout) } ================================================ FILE: dial_17.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
//go:build go1.7 // +build go1.7 package tchannel import ( "context" "net" ) func dialContext(ctx context.Context, hostPort string) (net.Conn, error) { d := net.Dialer{} return d.DialContext(ctx, "tcp", hostPort) } ================================================ FILE: dial_17_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. //go:build go1.7 // +build go1.7 package tchannel_test import ( "strings" "testing" "time" . "github.com/uber/tchannel-go" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go/testutils" ) func TestNetDialCancelContext(t *testing.T) { // timeoutHostPort uses a blackholed address (RFC 6890) with a port // reserved for documentation. This address should always cause a timeout. 
const timeoutHostPort = "192.18.0.254:44444" timeoutPeriod := testutils.Timeout(50 * time.Millisecond) client := testutils.NewClient(t, nil) defer client.Close() started := time.Now() ctx, cancel := NewContext(time.Minute) go func() { time.Sleep(timeoutPeriod) cancel() }() err := client.Ping(ctx, timeoutHostPort) if !assert.Error(t, err, "Ping to blackhole address should fail") { return } if strings.Contains(err.Error(), "network is unreachable") { t.Skipf("Skipping test, as network interface may not be available") } d := time.Since(started) assert.Equal(t, ErrCodeCancelled, GetSystemErrorCode(err), "Ping expected to fail with context cancelled") assert.True(t, d < 2*timeoutPeriod, "Timeout should take less than %v, took %v", 2*timeoutPeriod, d) } ================================================ FILE: doc.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
/*
Package tchannel implements Go bindings for the TChannel protocol (https://github.com/uber/tchannel).

A single Channel can be used for many concurrent requests to many hosts.
*/
package tchannel

================================================
FILE: errors.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"fmt"

	"golang.org/x/net/context"
)

const (
	// Message id for protocol level errors
	invalidMessageID uint32 = 0xFFFFFFFF
)

// A SystemErrCode indicates how a caller should handle a system error returned from a peer
type SystemErrCode byte

//go:generate stringer -type=SystemErrCode

const (
	// ErrCodeInvalid is an invalid error code, and should not be used
	ErrCodeInvalid SystemErrCode = 0x00

	// ErrCodeTimeout indicates the peer timed out. Callers can retry the request
	// on another peer if the request is safe to retry.
	ErrCodeTimeout SystemErrCode = 0x01

	// ErrCodeCancelled indicates that the request was cancelled on the peer. Callers
	// can retry the request on the same or another peer if the request is safe to retry
	ErrCodeCancelled SystemErrCode = 0x02

	// ErrCodeBusy indicates that the request was not dispatched because the peer
	// was too busy to handle it. Callers can retry the request on another peer, and should
	// reweight their connections to direct less traffic to this peer until it recovers.
	ErrCodeBusy SystemErrCode = 0x03

	// ErrCodeDeclined indicates that the request was not dispatched because the peer
	// declined to handle it, typically because the peer is not yet ready to handle it.
	// Callers can retry the request on another peer, but should not reweight their connections
	// and should continue to send traffic to this peer.
	ErrCodeDeclined SystemErrCode = 0x04

	// ErrCodeUnexpected indicates that the request failed for an unexpected reason, typically
	// a crash or other unexpected handling. The request may have been processed before the failure;
	// callers should retry the request on this or another peer only if the request is safe to retry
	ErrCodeUnexpected SystemErrCode = 0x05

	// ErrCodeBadRequest indicates that the request was malformed, and could not be processed.
	// Callers should not bother to retry the request, as there is no chance it will be handled.
	ErrCodeBadRequest SystemErrCode = 0x06

	// ErrCodeNetwork indicates a network level error, such as a connection reset.
	// Callers can retry the request if the request is safe to retry
	ErrCodeNetwork SystemErrCode = 0x07

	// ErrCodeProtocol indicates a fatal protocol error communicating with the peer. The connection
	// will be terminated.
	ErrCodeProtocol SystemErrCode = 0xFF
)

var (
	// ErrServerBusy is a SystemError indicating the server is busy
	ErrServerBusy = NewSystemError(ErrCodeBusy, "server busy")

	// ErrRequestCancelled is a SystemError indicating the request has been cancelled on the peer
	ErrRequestCancelled = NewSystemError(ErrCodeCancelled, "request cancelled")

	// ErrTimeout is a SystemError indicating the request has timed out
	ErrTimeout = NewSystemError(ErrCodeTimeout, "timeout")

	// ErrTimeoutRequired is a SystemError indicating that timeouts must be specified.
	ErrTimeoutRequired = NewSystemError(ErrCodeBadRequest, "timeout required")

	// ErrChannelClosed is a SystemError indicating that the channel has been closed.
	ErrChannelClosed = NewSystemError(ErrCodeDeclined, "closed channel")

	// ErrMethodTooLarge is a SystemError indicating that the method is too large.
	ErrMethodTooLarge = NewSystemError(ErrCodeProtocol, "method too large")
)

// MetricsKey is a string representation of the error code that's suitable for
// inclusion in metrics tags. Codes without an explicit case fall back to the
// code's String() form.
func (c SystemErrCode) MetricsKey() string {
	switch c {
	case ErrCodeInvalid:
		// Shouldn't ever need this.
		return "invalid"
	case ErrCodeTimeout:
		return "timeout"
	case ErrCodeCancelled:
		return "cancelled"
	case ErrCodeBusy:
		return "busy"
	case ErrCodeDeclined:
		return "declined"
	case ErrCodeUnexpected:
		return "unexpected-error"
	case ErrCodeBadRequest:
		return "bad-request"
	case ErrCodeNetwork:
		return "network-error"
	case ErrCodeProtocol:
		return "protocol-error"
	default:
		return c.String()
	}
}

// relayMetricsKey returns the same key as MetricsKey with a "relay-" prefix.
// The known codes are spelled out as literals; only the default case builds
// the string at call time.
func (c SystemErrCode) relayMetricsKey() string {
	switch c {
	case ErrCodeInvalid:
		return "relay-invalid"
	case ErrCodeTimeout:
		return "relay-timeout"
	case ErrCodeCancelled:
		return "relay-cancelled"
	case ErrCodeBusy:
		return "relay-busy"
	case ErrCodeDeclined:
		return "relay-declined"
	case ErrCodeUnexpected:
		return "relay-unexpected-error"
	case ErrCodeBadRequest:
		return "relay-bad-request"
	case ErrCodeNetwork:
		return "relay-network-error"
	case ErrCodeProtocol:
		return "relay-protocol-error"
	default:
		return "relay-" + c.String()
	}
}

// A SystemError is a system-level error, containing an error code and message
// TODO(mmihic): Probably we want to hide this interface, and let application code
// just deal with standard raw errors.
type SystemError struct { code SystemErrCode msg string wrapped error } // NewSystemError defines a new SystemError with a code and message func NewSystemError(code SystemErrCode, msg string, args ...interface{}) error { return SystemError{code: code, msg: fmt.Sprintf(msg, args...)} } // NewWrappedSystemError defines a new SystemError wrapping an existing error func NewWrappedSystemError(code SystemErrCode, wrapped error) error { if se, ok := wrapped.(SystemError); ok { return se } return SystemError{code: code, msg: fmt.Sprint(wrapped), wrapped: wrapped} } // Error returns the code and message, conforming to the error interface func (se SystemError) Error() string { return fmt.Sprintf("tchannel error %v: %s", se.Code(), se.msg) } // Wrapped returns the wrapped error func (se SystemError) Wrapped() error { return se.wrapped } // Code returns the SystemError code, for sending to a peer func (se SystemError) Code() SystemErrCode { return se.code } // Message returns the SystemError message. func (se SystemError) Message() string { return se.msg } // GetContextError converts the context error to a tchannel error. func GetContextError(err error) error { if err == context.DeadlineExceeded { return ErrTimeout } if err == context.Canceled { return ErrRequestCancelled } return err } // GetSystemErrorCode returns the code to report for the given error. If the error is a // SystemError, we can get the code directly. Otherwise treat it as an unexpected error func GetSystemErrorCode(err error) SystemErrCode { if err == nil { return ErrCodeInvalid } if se, ok := err.(SystemError); ok { return se.Code() } return ErrCodeUnexpected } // GetSystemErrorMessage returns the message to report for the given error. If the error is a // SystemError, we can get the underlying message. Otherwise, use the Error() method. 
func GetSystemErrorMessage(err error) string { if se, ok := err.(SystemError); ok { return se.Message() } return err.Error() } type errConnNotActive struct { info string state connectionState } func (e errConnNotActive) Error() string { return fmt.Sprintf("%v connection is not active: %v", e.info, e.state) } ================================================ FILE: errors_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "io" "regexp" "testing" "github.com/stretchr/testify/assert" ) func TestErrorMetricKeys(t *testing.T) { codes := []SystemErrCode{ ErrCodeInvalid, ErrCodeTimeout, ErrCodeCancelled, ErrCodeBusy, ErrCodeDeclined, ErrCodeUnexpected, ErrCodeBadRequest, ErrCodeNetwork, ErrCodeProtocol, } // Metrics keys should be all lowercase letters and dashes. No spaces, // underscores, or other characters. 
expected := regexp.MustCompile(`^[[:lower:]-]+$`) for _, c := range codes { assert.True(t, expected.MatchString(c.MetricsKey()), "Expected metrics key for code %s to be well-formed.", c.String()) } // Unexpected codes may have poorly-formed keys. assert.Equal(t, "SystemErrCode(13)", SystemErrCode(13).MetricsKey(), "Expected invalid error codes to use a fallback metrics key format.") } func TestInvalidError(t *testing.T) { code := GetSystemErrorCode(nil) assert.Equal(t, ErrCodeInvalid, code, "nil error should produce ErrCodeInvalid") } func TestUnexpectedError(t *testing.T) { code := GetSystemErrorCode(io.EOF) assert.Equal(t, ErrCodeUnexpected, code, "non-tchannel SystemError should produce ErrCodeUnexpected") } func TestSystemError(t *testing.T) { code := GetSystemErrorCode(ErrTimeout) assert.Equal(t, ErrCodeTimeout, code, "tchannel timeout error produces ErrCodeTimeout") } func TestRelayMetricsKey(t *testing.T) { for i := 0; i <= 256; i++ { code := SystemErrCode(i) assert.Equal(t, "relay-"+code.MetricsKey(), code.relayMetricsKey(), "Unexpected relay metrics key for %v", code) } } ================================================ FILE: examples/bench/client/client.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "flag" "log" "net/http" _ "net/http/pprof" "runtime" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "go.uber.org/atomic" "golang.org/x/net/context" ) var ( hostPort = flag.String("hostPort", "localhost:12345", "listening socket of the bench server") numGoroutines = flag.Int("numGo", 1, "The number of goroutines to spawn") numOSThreads = flag.Int("numThreads", 1, "The number of OS threads to use (sets GOMAXPROCS)") setBlockSize = flag.Int("setBlockSize", 4096, "The size in bytes of the data being set") getToSetRatio = flag.Int("getToSetRatio", 1, "The number of Gets to do per Set call") // counter tracks the total number of requests completed in the past second. counter atomic.Int64 ) func main() { flag.Parse() runtime.GOMAXPROCS(*numOSThreads) // Sets up a listener for pprof. 
go func() { log.Println(http.ListenAndServe("localhost:6061", nil)) }() ch, err := tchannel.NewChannel("benchmark-client", nil) if err != nil { log.Fatalf("NewChannel failed: %v", err) } for i := 0; i < *numGoroutines; i++ { go worker(ch) } log.Printf("client config: %v workers on %v threads, setBlockSize %v, getToSetRatio %v", *numGoroutines, *numOSThreads, *setBlockSize, *getToSetRatio) requestCountReporter() } func requestCountReporter() { for { time.Sleep(time.Second) cur := counter.Swap(0) log.Printf("%v requests", cur) } } func worker(ch *tchannel.Channel) { data := make([]byte, *setBlockSize) for { if err := setRequest(ch, "key", string(data)); err != nil { log.Fatalf("set failed: %v", err) continue } counter.Inc() for i := 0; i < *getToSetRatio; i++ { _, err := getRequest(ch, "key") if err != nil { log.Fatalf("get failed: %v", err) } counter.Inc() } } } func setRequest(ch *tchannel.Channel, key, value string) error { ctx, _ := context.WithTimeout(context.Background(), time.Second*10) _, _, _, err := raw.Call(ctx, ch, *hostPort, "benchmark", "set", []byte(key), []byte(value)) return err } func getRequest(ch *tchannel.Channel, key string) (string, error) { ctx, _ := context.WithTimeout(context.Background(), time.Second) _, arg3, _, err := raw.Call(ctx, ch, *hostPort, "benchmark", "get", []byte(key), nil) return string(arg3), err } ================================================ FILE: examples/bench/runner.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "flag" "fmt" "log" "net" "os" "os/exec" "time" ) var ( flagHostPort = flag.String("hostPort", "127.0.0.1:12345", "The host:port to run the benchmark on") flagServerNumThreads = flag.Int("serverThreads", 1, "The number of OS threads to use for the server") flagServerBinary = flag.String("serverBinary", "./build/examples/bench/server", "Server binary location") flagClientBinary = flag.String("clientBinary", "./build/examples/bench/client", "Client binary location") flagProfileAfter = flag.Duration("profileAfter", 0, "Duration to wait before profiling. 0 disables profiling. 
Process is stopped after the profile.") flagProfileSeconds = flag.Int("profileSeconds", 30, "The number of seconds to profile") flagProfileStop = flag.Bool("profileStopProcess", true, "Whether to stop the benchmarks after profiling") ) func main() { flag.Parse() server, err := runServer(*flagHostPort, *flagServerNumThreads) if err != nil { log.Fatalf("Server failed: %v", err) } defer server.Process.Kill() client, err := runClient(flag.Args()) if err != nil { log.Fatalf("Client failed: %v", err) } defer client.Process.Kill() if *flagProfileAfter != 0 { go func() { time.Sleep(*flagProfileAfter) // Profile the server and the client, which have pprof endpoints at :6060, and :6061 p1, err := dumpProfile("localhost:6060", "server.pb.gz") if err != nil { log.Printf("Server profile failed: %v", err) } p2, err := dumpProfile("localhost:6061", "client.pb.gz") if err != nil { log.Printf("Client profile failed: %v", err) } if err := p1.Wait(); err != nil { log.Printf("Server profile error: %v", err) } if err := p2.Wait(); err != nil { log.Printf("Client profile error: %v", err) } if !*flagProfileStop { return } server.Process.Signal(os.Interrupt) client.Process.Signal(os.Interrupt) // After a while, kill the processes if we're still running. time.Sleep(300 * time.Millisecond) log.Printf("Still waiting for processes to stop, sending kill") server.Process.Kill() client.Process.Kill() }() } server.Wait() client.Wait() } func runServer(hostPort string, numThreads int) (*exec.Cmd, error) { host, port, err := net.SplitHostPort(hostPort) if err != nil { return nil, err } return runCmd(*flagServerBinary, "--host", host, "--port", port, "--numThreads", fmt.Sprint(numThreads)) } func runClient(clientArgs []string) (*exec.Cmd, error) { return runCmd(*flagClientBinary, clientArgs...) 
} func dumpProfile(baseHostPort, profileFile string) (*exec.Cmd, error) { profileURL := fmt.Sprintf("http://%v/debug/pprof/profile", baseHostPort) return runCmd("go", "tool", "pprof", "--proto", "--output="+profileFile, fmt.Sprintf("--seconds=%v", *flagProfileSeconds), profileURL) } func runCmd(cmdBinary string, args ...string) (*exec.Cmd, error) { cmd := exec.Command(cmdBinary, args...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd, cmd.Start() } ================================================ FILE: examples/bench/server/server.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package main

import (
	"errors"
	"flag"
	"fmt"
	"log"
	"net/http"
	_ "net/http/pprof"
	"runtime"
	"sync"

	"github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/raw"
	"golang.org/x/net/context"
)

var (
	flagHost      = flag.String("host", "localhost", "The hostname to listen on")
	flagPort      = flag.Int("port", 12345, "The base port to listen on")
	flagInstances = flag.Int("instances", 1, "The number of instances to start")
	flagOSThreads = flag.Int("numThreads", 1, "The number of OS threads to use (sets GOMAXPROCS)")
)

// main parses flags, starts the pprof listener and flagInstances benchmark
// servers on consecutive ports, then blocks forever.
func main() {
	flag.Parse()
	runtime.GOMAXPROCS(*flagOSThreads)

	// Sets up a listener for pprof.
	go func() {
		log.Printf("server pprof endpoint failed: %v", http.ListenAndServe("localhost:6060", nil))
	}()

	for i := 0; i < *flagInstances; i++ {
		if err := setupServer(*flagHost, *flagPort, i); err != nil {
			log.Fatalf("setupServer %v failed: %v", i, err)
		}
	}

	log.Printf("server config: %v threads listening on %v:%v", *flagOSThreads, *flagHost, *flagPort)

	// Listen indefinitely.
	select {}
}

// setupServer creates a "benchmark" channel listening on host at
// basePort+instanceNum and registers the key-value handler for the "ping",
// "get" and "set" methods.
func setupServer(host string, basePort, instanceNum int) error {
	hostPort := fmt.Sprintf("%s:%v", host, basePort+instanceNum)
	ch, err := tchannel.NewChannel("benchmark", &tchannel.ChannelOptions{
		ProcessName: fmt.Sprintf("benchmark-%v", instanceNum),
	})
	if err != nil {
		return fmt.Errorf("NewChannel failed: %v", err)
	}

	handler := raw.Wrap(&kvHandler{vals: make(map[string]string)})
	ch.Register(handler, "ping")
	ch.Register(handler, "get")
	ch.Register(handler, "set")

	if err := ch.ListenAndServe(hostPort); err != nil {
		return fmt.Errorf("ListenAndServe failed: %v", err)
	}

	return nil
}

// kvHandler is an in-memory string key-value store guarded by the embedded
// RWMutex.
type kvHandler struct {
	sync.RWMutex
	vals map[string]string
}

// WithLock runs f while holding the write lock when write is true and the
// read lock otherwise. The lock is released via defer so it is not left held
// if f panics.
func (h *kvHandler) WithLock(write bool, f func()) {
	if write {
		h.Lock()
		defer h.Unlock()
	} else {
		h.RLock()
		defer h.RUnlock()
	}

	f()
}

// Ping replies with "pong" in Arg2.
func (h *kvHandler) Ping(ctx context.Context, args *raw.Args) (*raw.Res, error) {
	return &raw.Res{
		Arg2: []byte("pong"),
	}, nil
}

// Get looks up the key in args.Arg2 and returns the stored value in Arg3 with
// its length, formatted as text, in Arg2. A missing key yields an empty value.
func (h *kvHandler) Get(ctx context.Context, args *raw.Args) (*raw.Res, error) {
	var arg3 []byte
	h.WithLock(false /* write */, func() {
		arg3 = []byte(h.vals[string(args.Arg2)])
	})

	return &raw.Res{
		Arg2: []byte(fmt.Sprint(len(arg3))),
		Arg3: arg3,
	}, nil
}

// Set stores args.Arg3 under the key args.Arg2 and acknowledges with fixed
// response arguments.
func (h *kvHandler) Set(ctx context.Context, args *raw.Args) (*raw.Res, error) {
	h.WithLock(true /* write */, func() {
		h.vals[string(args.Arg2)] = string(args.Arg3)
	})
	return &raw.Res{
		Arg2: []byte("ok"),
		Arg3: []byte("really ok"),
	}, nil
}

// Handle dispatches on args.Method. The handler is registered under "set",
// so that name must be routed to Set; "put" stays accepted for backward
// compatibility.
func (h *kvHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) {
	switch args.Method {
	case "ping":
		return h.Ping(ctx, args)
	case "get":
		return h.Get(ctx, args)
	case "set", "put":
		return h.Set(ctx, args)
	default:
		return nil, errors.New("unknown method")
	}
}

// OnError terminates the process on any handler error.
func (h *kvHandler) OnError(ctx context.Context, err error) {
	log.Fatalf("OnError %v", err)
}

================================================
FILE: examples/hyperbahn/echo-server/main.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "fmt" "log" "net" "os" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/hyperbahn" "github.com/uber/tchannel-go/raw" "golang.org/x/net/context" ) func main() { tchan, err := tchannel.NewChannel("go-echo-server", nil) if err != nil { log.Fatalf("Failed to create channel: %v", err) } listenIP, err := tchannel.ListenIP() if err != nil { log.Fatalf("Failed to get IP to listen on: %v", err) } l, err := net.Listen("tcp", listenIP.String()+":61543") if err != nil { log.Fatalf("Could not listen: %v", err) } log.Printf("Listening on %v", l.Addr()) sc := tchan.GetSubChannel("go-echo-2") tchan.Register(raw.Wrap(handler{""}), "echo") sc.Register(raw.Wrap(handler{"subchannel:"}), "echo") tchan.Serve(l) if len(os.Args[1:]) == 0 { log.Fatalf("You must provide Hyperbahn nodes as arguments") } // advertise service with Hyperbahn. config := hyperbahn.Configuration{InitialNodes: os.Args[1:]} client, err := hyperbahn.NewClient(tchan, config, &hyperbahn.ClientOptions{ Handler: eventHandler{}, Timeout: time.Second, }) if err != nil { log.Fatalf("hyperbahn.NewClient failed: %v", err) } if err := client.Advertise(sc); err != nil { log.Fatalf("Advertise failed: %v", err) } // Server will keep running till Ctrl-C. 
select {} } type eventHandler struct{} func (eventHandler) On(event hyperbahn.Event) { fmt.Printf("On(%v)\n", event) } func (eventHandler) OnError(err error) { fmt.Printf("OnError(%v)\n", err) } type handler struct { prefix string } func (h handler) OnError(ctx context.Context, err error) { log.Fatalf("OnError: %v", err) } func (h handler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { arg2 := h.prefix + string(args.Arg2) arg3 := h.prefix + string(args.Arg3) return &raw.Res{ Arg2: []byte(arg2), Arg3: []byte(arg3), }, nil } ================================================ FILE: examples/hypercat/main.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package main

import (
	"io"
	"log"
	"net"
	"os"
	"os/exec"

	"github.com/jessevdk/go-flags"
	"github.com/uber/tchannel-go"
	"golang.org/x/net/context"
)

var options = struct {
	ServiceName string `short:"s" long:"service" required:"true" description:"The TChannel/Hyperbahn service name"`

	// MethodName can be specified multiple times to listen on multiple methods.
	MethodName []string `short:"o" long:"method" required:"true" description:"The method name to handle"`

	// HostPort can just be :port or port, in which case host defaults to tchannel's ListenIP.
	HostPort string `short:"l" long:"hostPort" default:":0" description:"The port or host:port to listen on"`

	MaxConcurrency int `short:"m" long:"maxSpawn" default:"1" description:"The maximum number concurrent processes"`

	Cmd struct {
		Command string   `long:"command" description:"The command to execute" positional-arg-name:"command"`
		Args    []string `long:"args" description:"The arguments to pass to the command" positional-arg-name:"args"`
	} `positional-args:"yes" required:"yes"`
}{}

// running is a counting semaphore, sized by options.MaxConcurrency, that
// bounds the number of handlers executing at once.
var running chan struct{}

// parseArgs fills options from the command line (exiting with status -1 on a
// parse failure), normalizes options.HostPort to a full host:port and creates
// the running semaphore.
func parseArgs() {
	if _, parseErr := flags.Parse(&options); parseErr != nil {
		os.Exit(-1)
	}

	// A value that does not split into host and port is taken to be a bare
	// port.
	host, port, splitErr := net.SplitHostPort(options.HostPort)
	if splitErr != nil {
		port = options.HostPort
	}

	// Fill in a missing host from tchannel's listen IP, falling back to
	// loopback when that cannot be determined.
	if host == "" {
		if ip, ipErr := tchannel.ListenIP(); ipErr == nil {
			host = ip.String()
		} else {
			log.Printf("could not get ListenIP: %v, defaulting to 127.0.0.1", ipErr)
			host = "127.0.0.1"
		}
	}
	options.HostPort = host + ":" + port

	running = make(chan struct{}, options.MaxConcurrency)
}

// main registers handler for every requested method, starts listening and
// blocks forever.
func main() {
	parseArgs()

	ch, err := tchannel.NewChannel(options.ServiceName, nil)
	if err != nil {
		log.Fatalf("NewChannel failed: %v", err)
	}

	h := tchannel.HandlerFunc(handler)
	for i := range options.MethodName {
		ch.Register(h, options.MethodName[i])
	}

	if err := ch.ListenAndServe(options.HostPort); err != nil {
		log.Fatalf("ListenAndServe failed: %v", err)
	}

	info := ch.PeerInfo()
	log.Printf("listening for %v:%v on %v", info.ServiceName, options.MethodName, info.HostPort)
	select {}
}

// onError logs the formatted message and terminates the process.
func onError(msg string, args ...interface{}) {
	log.Fatalf(msg, args...)
}

// handler reads arg2, writes an empty arg2 in the response, and pipes the
// request's arg3 through the configured command into the response's arg3.
// It holds a slot of the running semaphore for its whole duration, and any
// failure terminates the process.
func handler(ctx context.Context, call *tchannel.InboundCall) {
	running <- struct{}{}
	defer func() { <-running }()

	var arg2 []byte
	if err := tchannel.NewArgReader(call.Arg2Reader()).Read(&arg2); err != nil {
		onError("Arg2Reader failed: %v", err)
	}

	in, err := call.Arg3Reader()
	if err != nil {
		onError("Arg3Reader failed: %v", err)
	}

	resp := call.Response()
	if err := tchannel.NewArgWriter(resp.Arg2Writer()).Write(nil); err != nil {
		onError("Arg2Writer failed: %v", err)
	}

	out, err := resp.Arg3Writer()
	if err != nil {
		onError("Arg3Writer failed: %v", err)
	}

	if err := spawnProcess(in, out); err != nil {
		onError("spawnProcess failed: %v", err)
	}

	if err := in.Close(); err != nil {
		onError("Arg3Reader.Close failed: %v", err)
	}
	if err := out.Close(); err != nil {
		onError("Arg3Writer.Close failed: %v", err)
	}
}

// spawnProcess runs the configured command with reader as its stdin and
// writer as its stdout, forwarding stderr to this process's stderr, and
// returns the result of running it to completion.
func spawnProcess(reader io.Reader, writer io.Writer) error {
	child := exec.Command(options.Cmd.Command, options.Cmd.Args...)
	child.Stdin, child.Stdout, child.Stderr = reader, writer, os.Stderr
	return child.Run()
}

================================================
FILE: examples/keyvalue/README.md
================================================
# Key-Value Store

```bash
./build/examples/keyvalue/server
./build/examples/keyvalue/client
```

This example exposes a simple key-value store over TChannel using the Thrift protocol. The client has an interactive CLI that can be used to make calls to the server.

================================================
FILE: examples/keyvalue/client/client.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package main

import (
	"bufio"
	"fmt"
	"log"
	"os"
	"strings"
	"time"

	"github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/examples/keyvalue/gen-go/keyvalue"
	"github.com/uber/tchannel-go/hyperbahn"
	"github.com/uber/tchannel-go/thrift"
)

// curUser is sent as the "user" header on every request; the "user" command
// changes it.
var curUser = "anonymous"

// printHelp prints the list of supported commands.
func printHelp() {
	fmt.Println("Usage:\n get [key]\n set [key] [value]")
	fmt.Println(" user [newUser]\n clearAll")
}

// main connects to the Hyperbahn nodes given as arguments and runs an
// interactive command loop against the "keyvalue" service until stdin ends.
func main() {
	// Create a TChannel.
	ch, err := tchannel.NewChannel("keyvalue-client", nil)
	if err != nil {
		log.Fatalf("Failed to create tchannel: %v", err)
	}

	// Set up Hyperbahn client.
	config := hyperbahn.Configuration{InitialNodes: os.Args[1:]}
	if len(config.InitialNodes) == 0 {
		log.Fatalf("No Hyperbahn nodes to connect to given")
	}
	if _, err := hyperbahn.NewClient(ch, config, nil); err != nil {
		log.Fatalf("hyperbahn.NewClient failed: %v", err)
	}

	thriftClient := thrift.NewClient(ch, "keyvalue", nil)
	client := keyvalue.NewTChanKeyValueClient(thriftClient)
	adminClient := keyvalue.NewTChanAdminClient(thriftClient)

	// Read commands from the command line and execute them.
	scanner := bufio.NewScanner(os.Stdin)
	printHelp()
	fmt.Printf("> ")
	for scanner.Scan() {
		// Fields tolerates leading, trailing and repeated whitespace.
		parts := strings.Fields(scanner.Text())
		if len(parts) == 0 {
			// Blank input: show the prompt again rather than leaving the
			// user without one.
			fmt.Print("> ")
			continue
		}
		switch parts[0] {
		case "help":
			printHelp()
		case "get":
			if len(parts) < 2 {
				printHelp()
				break
			}
			get(client, parts[1])
		case "set":
			if len(parts) < 3 {
				printHelp()
				break
			}
			set(client, parts[1], parts[2])
		case "user":
			if len(parts) < 2 {
				printHelp()
				break
			}
			curUser = parts[1]
		case "clearAll":
			clear(adminClient)
		default:
			log.Printf("Unsupported command %q\n", parts[0])
		}
		fmt.Print("> ")
	}
	// Scan returns false on both EOF and read errors; report the latter.
	if err := scanner.Err(); err != nil {
		log.Fatalf("Failed to read input: %v", err)
	}
}

// get fetches key from the service and logs either the value or the reason
// the call failed.
func get(client keyvalue.TChanKeyValue, key string) {
	ctx, cancel := createContext()
	defer cancel()

	val, err := client.Get(ctx, key)
	if err != nil {
		switch err := err.(type) {
		case *keyvalue.InvalidKey:
			log.Printf("Get %v failed: invalid key", key)
		case *keyvalue.KeyNotFound:
			log.Printf("Get %v failed: key not found", key)
		default:
			log.Printf("Get %v failed unexpectedly: %v", key, err)
		}
		return
	}

	log.Printf("Get %v: %v", key, val)
}

// set stores value under key and logs the outcome, including the response
// headers on success.
func set(client keyvalue.TChanKeyValue, key, value string) {
	ctx, cancel := createContext()
	defer cancel()

	if err := client.Set(ctx, key, value); err != nil {
		switch err := err.(type) {
		case *keyvalue.InvalidKey:
			log.Printf("Set %v failed: invalid key", key)
		default:
			log.Printf("Set %v:%v failed unexpectedly: %#v", key, value, err)
		}
		return
	}

	log.Printf("Set %v:%v succeeded with headers: %v", key, value, ctx.ResponseHeaders())
}

// clear calls ClearAll on the admin service and logs the outcome.
func clear(adminClient keyvalue.TChanAdmin) {
	ctx, cancel := createContext()
	defer cancel()

	if err := adminClient.ClearAll(ctx); err != nil {
		switch err := err.(type) {
		case *keyvalue.NotAuthorized:
			log.Printf("You are not authorized to perform this method")
		default:
			log.Printf("ClearAll failed unexpectedly: %v", err)
		}
		return
	}

	log.Printf("ClearAll completed, all keys cleared")
}

// createContext returns a one-second thrift context carrying the current
// user in the "user" header, together with its cancel function.
func createContext() (thrift.Context, func()) {
	ctx, cancel := thrift.NewContext(time.Second)
	ctx = thrift.WithHeaders(ctx, map[string]string{"user": curUser})
	return ctx, cancel
}

================================================
FILE: examples/keyvalue/gen-go/keyvalue/admin.go
================================================
// Autogenerated by Thrift Compiler (1.0.0-dev)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING

package keyvalue

import (
	"bytes"
	"fmt"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
)

// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

// Admin is the handler/client interface of the Admin service: everything in
// BaseService plus ClearAll.
type Admin interface {
	BaseService

	ClearAll() (err error)
}

// AdminClient is a client for the Admin service. It embeds
// *BaseServiceClient and uses its transport, protocols and sequence id.
type AdminClient struct {
	*BaseServiceClient
}

// NewAdminClientFactory returns an AdminClient whose embedded base client is
// built from transport t and protocol factory f.
func NewAdminClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *AdminClient {
	return &AdminClient{BaseServiceClient: NewBaseServiceClientFactory(t, f)}
}

// NewAdminClientProtocol returns an AdminClient whose embedded base client
// uses the given input and output protocols.
func NewAdminClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *AdminClient {
	return &AdminClient{BaseServiceClient: NewBaseServiceClientProtocol(t, iprot, oprot)}
}

// ClearAll sends a "clearAll" call and then reads its reply, returning the
// first error from either step.
func (p *AdminClient) ClearAll() (err error) {
	if err = p.sendClearAll(); err != nil {
		return
	}
	return p.recvClearAll()
}

// sendClearAll increments SeqId, writes a "clearAll" CALL message with empty
// arguments to the output protocol and flushes it. The output protocol is
// created from ProtocolFactory if it is not set yet.
func (p *AdminClient) sendClearAll() (err error) {
	oprot := p.OutputProtocol
	if oprot == nil {
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("clearAll", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := AdminClearAllArgs{}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvClearAll reads the reply to a "clearAll" call. It rejects replies with
// a different method name or sequence id, decodes EXCEPTION messages into the
// returned error, and returns result.NotAuthorized as the error when the
// result carries one.
func (p *AdminClient) recvClearAll() (err error) {
	iprot := p.InputProtocol
	if iprot == nil {
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "clearAll" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "clearAll failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "clearAll failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		error12 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error13 error
		error13, err = error12.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		err = error13
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "clearAll failed: invalid message type")
		return
	}
	result := AdminClearAllResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	if result.NotAuthorized != nil {
		err = result.NotAuthorized
		return
	}
	return
}

// AdminProcessor dispatches incoming Admin service calls. It embeds
// *BaseServiceProcessor and reuses its processor map.
type AdminProcessor struct {
	*BaseServiceProcessor
}

// NewAdminProcessor returns an AdminProcessor for handler with the
// "clearAll" method registered on top of the base service methods.
func NewAdminProcessor(handler Admin) *AdminProcessor {
	self14 := &AdminProcessor{NewBaseServiceProcessor(handler)}
	self14.AddToProcessorMap("clearAll", &adminProcessorClearAll{handler: handler})
	return self14
}

// adminProcessorClearAll is the processor function for "clearAll".
type adminProcessorClearAll struct {
	handler Admin
}

// Process reads the clearAll arguments from iprot, invokes the handler and
// writes the reply to oprot. If the arguments cannot be read it writes a
// PROTOCOL_ERROR exception and returns false. A *NotAuthorized handler error
// is sent back inside the result; any other handler error is written as an
// INTERNAL_ERROR exception and returned with success set to true.
func (p *adminProcessorClearAll) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := AdminClearAllArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("clearAll", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := AdminClearAllResult{}
	var err2 error
	if err2 = p.handler.ClearAll(); err2 != nil {
		switch v := err2.(type) {
		case *NotAuthorized:
			result.NotAuthorized = v
		default:
			x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing clearAll: "+err2.Error())
			oprot.WriteMessageBegin("clearAll", thrift.EXCEPTION, seqId)
			x.Write(oprot)
			oprot.WriteMessageEnd()
			oprot.Flush()
			return true, err2
		}
	}
	// Every write step below still runs after a failure; only the first
	// error is kept in err.
	if err2 = oprot.WriteMessageBegin("clearAll", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		return
	}
	return true, err
}

// HELPER FUNCTIONS AND STRUCTURES

// AdminClearAllArgs is the (empty) argument struct of the clearAll call.
type AdminClearAllArgs struct {
}

// NewAdminClearAllArgs returns a new, empty AdminClearAllArgs.
func NewAdminClearAllArgs() *AdminClearAllArgs {
	return &AdminClearAllArgs{}
}

// Read consumes a struct from iprot, skipping every field until the STOP
// marker, since this struct declares no fields.
func (p *AdminClearAllArgs) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		if err := iprot.Skip(fieldTypeId); err != nil {
			return err
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// Write writes the struct to oprot as "clearAll_args" with no fields.
func (p *AdminClearAllArgs) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("clearAll_args"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// String returns a printable form of the struct, or "" for a nil receiver.
func (p *AdminClearAllArgs) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("AdminClearAllArgs(%+v)", *p)
}

// AdminClearAllResult is the result struct of the clearAll call.
//
// Attributes:
//  - NotAuthorized
type AdminClearAllResult struct {
	NotAuthorized *NotAuthorized `thrift:"notAuthorized,1" db:"notAuthorized" json:"notAuthorized,omitempty"`
}

// NewAdminClearAllResult returns a new AdminClearAllResult with no fields set.
func NewAdminClearAllResult() *AdminClearAllResult {
	return &AdminClearAllResult{}
}

// AdminClearAllResult_NotAuthorized_DEFAULT is returned by GetNotAuthorized
// when the field is not set.
var AdminClearAllResult_NotAuthorized_DEFAULT *NotAuthorized

// GetNotAuthorized returns the NotAuthorized field, or the default when it is
// not set.
func (p *AdminClearAllResult) GetNotAuthorized() *NotAuthorized {
	if !p.IsSetNotAuthorized() {
		return AdminClearAllResult_NotAuthorized_DEFAULT
	}
	return p.NotAuthorized
}

// IsSetNotAuthorized reports whether the NotAuthorized field is non-nil.
func (p *AdminClearAllResult) IsSetNotAuthorized() bool {
	return p.NotAuthorized != nil
}

// Read decodes the struct from iprot: field id 1 is read into NotAuthorized,
// all other fields are skipped.
func (p *AdminClearAllResult) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		switch fieldId {
		case 1:
			if err := p.ReadField1(iprot); err != nil {
				return err
			}
		default:
			if err := iprot.Skip(fieldTypeId); err != nil {
				return err
			}
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// ReadField1 allocates NotAuthorized and decodes it from iprot.
func (p *AdminClearAllResult) ReadField1(iprot thrift.TProtocol) error {
	p.NotAuthorized = &NotAuthorized{}
	if err := p.NotAuthorized.Read(iprot); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.NotAuthorized), err)
	}
	return nil
}

// Write encodes the struct to oprot as "clearAll_result".
func (p *AdminClearAllResult) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("clearAll_result"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := p.writeField1(oprot); err != nil {
		return err
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// writeField1 writes the notAuthorized field (id 1) when it is set and writes
// nothing otherwise.
func (p *AdminClearAllResult) writeField1(oprot thrift.TProtocol) (err error) {
	if p.IsSetNotAuthorized() {
		if err := oprot.WriteFieldBegin("notAuthorized", thrift.STRUCT, 1); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:notAuthorized: ", p), err)
		}
		if err := p.NotAuthorized.Write(oprot); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.NotAuthorized), err)
		}
		if err := oprot.WriteFieldEnd(); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field end error 1:notAuthorized: ", p), err)
		}
	}
	return err
}

// String returns a printable form of the struct, or "" for a nil receiver.
func (p *AdminClearAllResult) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("AdminClearAllResult(%+v)", *p)
}

================================================
FILE: examples/keyvalue/gen-go/keyvalue/baseservice.go
================================================
// Autogenerated by Thrift Compiler (1.0.0-dev)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING

package keyvalue

import (
	"bytes"
	"fmt"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
)

// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

// BaseService is the handler/client interface of the base service, which
// exposes a single HealthCheck method.
type BaseService interface {
	HealthCheck() (r string, err error)
}

// BaseServiceClient is a client for BaseService. It holds the transport, the
// protocols used for reading and writing, and the sequence id of the most
// recently sent call.
type BaseServiceClient struct {
	Transport       thrift.TTransport
	ProtocolFactory thrift.TProtocolFactory
	InputProtocol   thrift.TProtocol
	OutputProtocol  thrift.TProtocol
	SeqId           int32
}

// NewBaseServiceClientFactory returns a client on transport t whose input and
// output protocols are each obtained from f.
func NewBaseServiceClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *BaseServiceClient {
	return &BaseServiceClient{Transport: t,
		ProtocolFactory: f,
		InputProtocol:   f.GetProtocol(t),
		OutputProtocol:  f.GetProtocol(t),
		SeqId:           0,
	}
}

// NewBaseServiceClientProtocol returns a client on transport t that uses the
// given protocols directly and has no protocol factory.
func NewBaseServiceClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *BaseServiceClient {
	return &BaseServiceClient{Transport: t,
		ProtocolFactory: nil,
		InputProtocol:   iprot,
		OutputProtocol:  oprot,
		SeqId:           0,
	}
}

// HealthCheck sends a "HealthCheck" call and then reads its reply, returning
// the first error from either step.
func (p *BaseServiceClient) HealthCheck() (r string, err error) {
	if err = p.sendHealthCheck(); err != nil {
		return
	}
	return p.recvHealthCheck()
}

// sendHealthCheck increments SeqId, writes a "HealthCheck" CALL message with
// empty arguments to the output protocol and flushes it. The output protocol
// is created from ProtocolFactory if it is not set yet.
func (p *BaseServiceClient) sendHealthCheck() (err error) {
	oprot := p.OutputProtocol
	if oprot == nil {
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("HealthCheck", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := BaseServiceHealthCheckArgs{}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvHealthCheck reads the reply to a "HealthCheck" call. It rejects replies
// with a different method name or sequence id, decodes EXCEPTION messages
// into the returned error, and otherwise returns the result's success value.
func (p *BaseServiceClient) recvHealthCheck() (value string, err error) {
	iprot := p.InputProtocol
	if iprot == nil {
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "HealthCheck" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "HealthCheck failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "HealthCheck failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		error0 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error1 error
		error1, err = error0.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		err = error1
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "HealthCheck failed: invalid message type")
		return
	}
	result := BaseServiceHealthCheckResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	value = result.GetSuccess()
	return
}

// BaseServiceProcessor dispatches incoming calls to processor functions
// registered by method name.
type BaseServiceProcessor struct {
	processorMap map[string]thrift.TProcessorFunction
	handler      BaseService
}

// AddToProcessorMap registers processor under key, replacing any existing
// entry for that key.
func (p *BaseServiceProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
	p.processorMap[key] = processor
}

// GetProcessorFunction returns the processor registered under key and whether
// one was found.
func (p *BaseServiceProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
	processor, ok = p.processorMap[key]
	return processor, ok
}

// ProcessorMap returns the underlying method-name-to-processor map (not a
// copy).
func (p *BaseServiceProcessor) ProcessorMap() map[string]thrift.TProcessorFunction {
	return p.processorMap
}

// NewBaseServiceProcessor returns a processor for handler with the
// "HealthCheck" method registered.
func NewBaseServiceProcessor(handler BaseService) *BaseServiceProcessor {

	self2 := &BaseServiceProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)}
	self2.processorMap["HealthCheck"] = &baseServiceProcessorHealthCheck{handler: handler}
	return self2
}

// Process reads one message header from iprot and dispatches to the processor
// registered for its method name. For an unknown method it skips the message
// body, writes an UNKNOWN_METHOD exception to oprot and returns it.
func (p *BaseServiceProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	name, _, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return false, err
	}
	if processor, ok := p.GetProcessorFunction(name); ok {
		return processor.Process(seqId, iprot, oprot)
	}
	iprot.Skip(thrift.STRUCT)
	iprot.ReadMessageEnd()
	x3 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)
	oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)
	x3.Write(oprot)
	oprot.WriteMessageEnd()
	oprot.Flush()
	return false, x3

}

// baseServiceProcessorHealthCheck is the processor function for "HealthCheck".
type baseServiceProcessorHealthCheck struct {
	handler BaseService
}

// Process reads the HealthCheck arguments from iprot, invokes the handler and
// writes the reply to oprot. If the arguments cannot be read it writes a
// PROTOCOL_ERROR exception and returns false. A handler error is written as
// an INTERNAL_ERROR exception and returned with success set to true.
func (p *baseServiceProcessorHealthCheck) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := BaseServiceHealthCheckArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("HealthCheck", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := BaseServiceHealthCheckResult{}
	var retval string
	var err2 error
	if retval, err2 = p.handler.HealthCheck(); err2 != nil {
		x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing HealthCheck: "+err2.Error())
		oprot.WriteMessageBegin("HealthCheck", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return true, err2
	} else {
		result.Success = &retval
	}
	// Every write step below still runs after a failure; only the first
	// error is kept in err.
	if err2 = oprot.WriteMessageBegin("HealthCheck", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		return
	}
	return true, err
}

// HELPER FUNCTIONS AND STRUCTURES

// BaseServiceHealthCheckArgs is the (empty) argument struct of the
// HealthCheck call.
type BaseServiceHealthCheckArgs struct {
}

// NewBaseServiceHealthCheckArgs returns a new, empty
// BaseServiceHealthCheckArgs.
func NewBaseServiceHealthCheckArgs() *BaseServiceHealthCheckArgs {
	return &BaseServiceHealthCheckArgs{}
}

// Read consumes a struct from iprot, skipping every field until the STOP
// marker, since this struct declares no fields.
func (p *BaseServiceHealthCheckArgs) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		if err := iprot.Skip(fieldTypeId); err != nil {
			return err
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// Write writes the struct to oprot as "HealthCheck_args" with no fields.
func (p *BaseServiceHealthCheckArgs) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("HealthCheck_args"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// String returns a printable form of the struct, or "" for a nil receiver.
func (p *BaseServiceHealthCheckArgs) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("BaseServiceHealthCheckArgs(%+v)", *p)
}

// BaseServiceHealthCheckResult is the result struct of the HealthCheck call.
//
// Attributes:
//  - Success
type BaseServiceHealthCheckResult struct {
	Success *string `thrift:"success,0" db:"success" json:"success,omitempty"`
}

// NewBaseServiceHealthCheckResult returns a new BaseServiceHealthCheckResult
// with no fields set.
func NewBaseServiceHealthCheckResult() *BaseServiceHealthCheckResult {
	return &BaseServiceHealthCheckResult{}
}

// BaseServiceHealthCheckResult_Success_DEFAULT is returned by GetSuccess when
// the field is not set.
var BaseServiceHealthCheckResult_Success_DEFAULT string

// GetSuccess returns the Success value, or the default when it is not set.
func (p *BaseServiceHealthCheckResult) GetSuccess() string {
	if !p.IsSetSuccess() {
		return BaseServiceHealthCheckResult_Success_DEFAULT
	}
	return *p.Success
}

// IsSetSuccess reports whether the Success field is non-nil.
func (p *BaseServiceHealthCheckResult) IsSetSuccess() bool {
	return p.Success != nil
}

// Read decodes the struct from iprot: field id 0 is read into Success, all
// other fields are skipped.
func (p *BaseServiceHealthCheckResult) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		switch fieldId {
		case 0:
			if err := p.ReadField0(iprot); err != nil {
				return err
			}
		default:
			if err := iprot.Skip(fieldTypeId); err != nil {
				return err
			}
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// ReadField0 reads a string from iprot and stores it in Success.
func (p *BaseServiceHealthCheckResult) ReadField0(iprot thrift.TProtocol) error {
	if v, err := iprot.ReadString(); err != nil {
		return thrift.PrependError("error reading field 0: ", err)
	} else {
		p.Success = &v
	}
	return nil
}

// Write encodes the struct to oprot as "HealthCheck_result".
func (p *BaseServiceHealthCheckResult) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("HealthCheck_result"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := p.writeField0(oprot); err != nil {
		return err
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// writeField0 writes the success field (id 0) when it is set and writes
// nothing otherwise.
func (p *BaseServiceHealthCheckResult) writeField0(oprot thrift.TProtocol) (err error) {
	if p.IsSetSuccess() {
		if err := oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err)
		}
		if err := oprot.WriteString(string(*p.Success)); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T.success (0) field write error: ", p), err)
		}
		if err := oprot.WriteFieldEnd(); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err)
		}
	}
	return err
}

// String returns a printable form of the struct, or "" for a nil receiver.
func (p *BaseServiceHealthCheckResult) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("BaseServiceHealthCheckResult(%+v)", *p)
}

================================================
FILE: examples/keyvalue/gen-go/keyvalue/constants.go
================================================
// Autogenerated by Thrift Compiler (1.0.0-dev)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING

package keyvalue

import (
	"bytes"
	"fmt"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
)

// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

// init is intentionally empty: no constants are initialized here.
func init() {
}

================================================
FILE: examples/keyvalue/gen-go/keyvalue/keyvalue.go
================================================
// Autogenerated by Thrift Compiler (1.0.0-dev)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING

package keyvalue

import (
	"bytes"
	"fmt"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
)

// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal type KeyValue interface { BaseService // Parameters: // - Key Get(key string) (r string, err error) // Parameters: // - Key // - Value Set(key string, value string) (err error) } type KeyValueClient struct { *BaseServiceClient } func NewKeyValueClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *KeyValueClient { return &KeyValueClient{BaseServiceClient: NewBaseServiceClientFactory(t, f)} } func NewKeyValueClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *KeyValueClient { return &KeyValueClient{BaseServiceClient: NewBaseServiceClientProtocol(t, iprot, oprot)} } // Parameters: // - Key func (p *KeyValueClient) Get(key string) (r string, err error) { if err = p.sendGet(key); err != nil { return } return p.recvGet() } func (p *KeyValueClient) sendGet(key string) (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("Get", thrift.CALL, p.SeqId); err != nil { return } args := KeyValueGetArgs{ Key: key, } if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *KeyValueClient) recvGet() (value string, err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "Get" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Get failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Get failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error4 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error5 error error5, err = error4.Read(iprot) if err != nil { 
return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error5 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Get failed: invalid message type") return } result := KeyValueGetResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } if result.NotFound != nil { err = result.NotFound return } else if result.InvalidKey != nil { err = result.InvalidKey return } value = result.GetSuccess() return } // Parameters: // - Key // - Value func (p *KeyValueClient) Set(key string, value string) (err error) { if err = p.sendSet(key, value); err != nil { return } return p.recvSet() } func (p *KeyValueClient) sendSet(key string, value string) (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("Set", thrift.CALL, p.SeqId); err != nil { return } args := KeyValueSetArgs{ Key: key, Value: value, } if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *KeyValueClient) recvSet() (err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "Set" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Set failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Set failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error6 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error7 error error7, err = error6.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error7 return } if mTypeId != 
thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Set failed: invalid message type") return } result := KeyValueSetResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } if result.InvalidKey != nil { err = result.InvalidKey return } return } type KeyValueProcessor struct { *BaseServiceProcessor } func NewKeyValueProcessor(handler KeyValue) *KeyValueProcessor { self8 := &KeyValueProcessor{NewBaseServiceProcessor(handler)} self8.AddToProcessorMap("Get", &keyValueProcessorGet{handler: handler}) self8.AddToProcessorMap("Set", &keyValueProcessorSet{handler: handler}) return self8 } type keyValueProcessorGet struct { handler KeyValue } func (p *keyValueProcessorGet) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := KeyValueGetArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("Get", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := KeyValueGetResult{} var retval string var err2 error if retval, err2 = p.handler.Get(args.Key); err2 != nil { switch v := err2.(type) { case *KeyNotFound: result.NotFound = v case *InvalidKey: result.InvalidKey = v default: x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Get: "+err2.Error()) oprot.WriteMessageBegin("Get", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } } else { result.Success = &retval } if err2 = oprot.WriteMessageBegin("Get", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != 
nil { return } return true, err } type keyValueProcessorSet struct { handler KeyValue } func (p *keyValueProcessorSet) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := KeyValueSetArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("Set", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := KeyValueSetResult{} var err2 error if err2 = p.handler.Set(args.Key, args.Value); err2 != nil { switch v := err2.(type) { case *InvalidKey: result.InvalidKey = v default: x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Set: "+err2.Error()) oprot.WriteMessageBegin("Set", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } } if err2 = oprot.WriteMessageBegin("Set", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } // HELPER FUNCTIONS AND STRUCTURES // Attributes: // - Key type KeyValueGetArgs struct { Key string `thrift:"key,1" db:"key" json:"key"` } func NewKeyValueGetArgs() *KeyValueGetArgs { return &KeyValueGetArgs{} } func (p *KeyValueGetArgs) GetKey() string { return p.Key } func (p *KeyValueGetArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := 
p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *KeyValueGetArgs) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Key = v } return nil } func (p *KeyValueGetArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Get_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *KeyValueGetArgs) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("key", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:key: ", p), err) } if err := oprot.WriteString(string(p.Key)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.key (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:key: ", p), err) } return err } func (p *KeyValueGetArgs) String() string { if p == nil { return "" } return fmt.Sprintf("KeyValueGetArgs(%+v)", *p) } // Attributes: // - Success // - NotFound // - InvalidKey type KeyValueGetResult struct { Success *string `thrift:"success,0" db:"success" json:"success,omitempty"` NotFound *KeyNotFound `thrift:"notFound,1" db:"notFound" json:"notFound,omitempty"` InvalidKey *InvalidKey `thrift:"invalidKey,2" db:"invalidKey" 
json:"invalidKey,omitempty"` } func NewKeyValueGetResult() *KeyValueGetResult { return &KeyValueGetResult{} } var KeyValueGetResult_Success_DEFAULT string func (p *KeyValueGetResult) GetSuccess() string { if !p.IsSetSuccess() { return KeyValueGetResult_Success_DEFAULT } return *p.Success } var KeyValueGetResult_NotFound_DEFAULT *KeyNotFound func (p *KeyValueGetResult) GetNotFound() *KeyNotFound { if !p.IsSetNotFound() { return KeyValueGetResult_NotFound_DEFAULT } return p.NotFound } var KeyValueGetResult_InvalidKey_DEFAULT *InvalidKey func (p *KeyValueGetResult) GetInvalidKey() *InvalidKey { if !p.IsSetInvalidKey() { return KeyValueGetResult_InvalidKey_DEFAULT } return p.InvalidKey } func (p *KeyValueGetResult) IsSetSuccess() bool { return p.Success != nil } func (p *KeyValueGetResult) IsSetNotFound() bool { return p.NotFound != nil } func (p *KeyValueGetResult) IsSetInvalidKey() bool { return p.InvalidKey != nil } func (p *KeyValueGetResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 0: if err := p.ReadField0(iprot); err != nil { return err } case 1: if err := p.ReadField1(iprot); err != nil { return err } case 2: if err := p.ReadField2(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *KeyValueGetResult) ReadField0(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { p.Success = 
&v } return nil } func (p *KeyValueGetResult) ReadField1(iprot thrift.TProtocol) error { p.NotFound = &KeyNotFound{} if err := p.NotFound.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.NotFound), err) } return nil } func (p *KeyValueGetResult) ReadField2(iprot thrift.TProtocol) error { p.InvalidKey = &InvalidKey{} if err := p.InvalidKey.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.InvalidKey), err) } return nil } func (p *KeyValueGetResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Get_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField0(oprot); err != nil { return err } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *KeyValueGetResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err := oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err) } if err := oprot.WriteString(string(*p.Success)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.success (0) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err) } } return err } func (p *KeyValueGetResult) writeField1(oprot thrift.TProtocol) (err error) { if p.IsSetNotFound() { if err := oprot.WriteFieldBegin("notFound", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:notFound: ", p), err) } if err := 
p.NotFound.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.NotFound), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:notFound: ", p), err) } } return err } func (p *KeyValueGetResult) writeField2(oprot thrift.TProtocol) (err error) { if p.IsSetInvalidKey() { if err := oprot.WriteFieldBegin("invalidKey", thrift.STRUCT, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:invalidKey: ", p), err) } if err := p.InvalidKey.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.InvalidKey), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:invalidKey: ", p), err) } } return err } func (p *KeyValueGetResult) String() string { if p == nil { return "" } return fmt.Sprintf("KeyValueGetResult(%+v)", *p) } // Attributes: // - Key // - Value type KeyValueSetArgs struct { Key string `thrift:"key,1" db:"key" json:"key"` Value string `thrift:"value,2" db:"value" json:"value"` } func NewKeyValueSetArgs() *KeyValueSetArgs { return &KeyValueSetArgs{} } func (p *KeyValueSetArgs) GetKey() string { return p.Key } func (p *KeyValueSetArgs) GetValue() string { return p.Value } func (p *KeyValueSetArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } case 2: if err := p.ReadField2(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err 
} } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *KeyValueSetArgs) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Key = v } return nil } func (p *KeyValueSetArgs) ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.Value = v } return nil } func (p *KeyValueSetArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Set_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *KeyValueSetArgs) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("key", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:key: ", p), err) } if err := oprot.WriteString(string(p.Key)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.key (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:key: ", p), err) } return err } func (p *KeyValueSetArgs) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("value", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:value: ", p), err) } if err := oprot.WriteString(string(p.Value)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.value (2) field write 
error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:value: ", p), err) } return err } func (p *KeyValueSetArgs) String() string { if p == nil { return "" } return fmt.Sprintf("KeyValueSetArgs(%+v)", *p) } // Attributes: // - InvalidKey type KeyValueSetResult struct { InvalidKey *InvalidKey `thrift:"invalidKey,1" db:"invalidKey" json:"invalidKey,omitempty"` } func NewKeyValueSetResult() *KeyValueSetResult { return &KeyValueSetResult{} } var KeyValueSetResult_InvalidKey_DEFAULT *InvalidKey func (p *KeyValueSetResult) GetInvalidKey() *InvalidKey { if !p.IsSetInvalidKey() { return KeyValueSetResult_InvalidKey_DEFAULT } return p.InvalidKey } func (p *KeyValueSetResult) IsSetInvalidKey() bool { return p.InvalidKey != nil } func (p *KeyValueSetResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *KeyValueSetResult) ReadField1(iprot thrift.TProtocol) error { p.InvalidKey = &InvalidKey{} if err := p.InvalidKey.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.InvalidKey), err) } return nil } func (p *KeyValueSetResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Set_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: 
", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *KeyValueSetResult) writeField1(oprot thrift.TProtocol) (err error) { if p.IsSetInvalidKey() { if err := oprot.WriteFieldBegin("invalidKey", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:invalidKey: ", p), err) } if err := p.InvalidKey.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.InvalidKey), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:invalidKey: ", p), err) } } return err } func (p *KeyValueSetResult) String() string { if p == nil { return "" } return fmt.Sprintf("KeyValueSetResult(%+v)", *p) } ================================================ FILE: examples/keyvalue/gen-go/keyvalue/tchan-keyvalue.go ================================================ // @generated Code generated by thrift-gen. Do not modify. // Package keyvalue is generated code used to make or handle TChannel calls using Thrift. package keyvalue import ( "fmt" athrift "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" "github.com/uber/tchannel-go/thrift" ) // Interfaces for the service and client for the services defined in the IDL. // TChanAdmin is the interface that defines the server handler and client interface. type TChanAdmin interface { TChanBaseService ClearAll(ctx thrift.Context) error } // TChanKeyValue is the interface that defines the server handler and client interface. 
type TChanKeyValue interface { TChanBaseService Get(ctx thrift.Context, key string) (string, error) Set(ctx thrift.Context, key string, value string) error } // TChanBaseService is the interface that defines the server handler and client interface. type TChanBaseService interface { HealthCheck(ctx thrift.Context) (string, error) } // Implementation of a client and service handler. type tchanAdminClient struct { TChanBaseService thriftService string client thrift.TChanClient } func NewTChanAdminInheritedClient(thriftService string, client thrift.TChanClient) *tchanAdminClient { return &tchanAdminClient{ NewTChanBaseServiceInheritedClient(thriftService, client), thriftService, client, } } // NewTChanAdminClient creates a client that can be used to make remote calls. func NewTChanAdminClient(client thrift.TChanClient) TChanAdmin { return NewTChanAdminInheritedClient("Admin", client) } func (c *tchanAdminClient) ClearAll(ctx thrift.Context) error { var resp AdminClearAllResult args := AdminClearAllArgs{} success, err := c.client.Call(ctx, c.thriftService, "clearAll", &args, &resp) if err == nil && !success { switch { case resp.NotAuthorized != nil: err = resp.NotAuthorized default: err = fmt.Errorf("received no result or unknown exception for clearAll") } } return err } type tchanAdminServer struct { thrift.TChanServer handler TChanAdmin } // NewTChanAdminServer wraps a handler for TChanAdmin so it can be // registered with a thrift.Server. 
func NewTChanAdminServer(handler TChanAdmin) thrift.TChanServer { return &tchanAdminServer{ NewTChanBaseServiceServer(handler), handler, } } func (s *tchanAdminServer) Service() string { return "Admin" } func (s *tchanAdminServer) Methods() []string { return []string{ "clearAll", "HealthCheck", } } func (s *tchanAdminServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "clearAll": return s.handleClearAll(ctx, protocol) case "HealthCheck": return s.TChanServer.Handle(ctx, methodName, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanAdminServer) handleClearAll(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req AdminClearAllArgs var res AdminClearAllResult if err := req.Read(protocol); err != nil { return false, nil, err } err := s.handler.ClearAll(ctx) if err != nil { switch v := err.(type) { case *NotAuthorized: if v == nil { return false, nil, fmt.Errorf("Handler for notAuthorized returned non-nil error type *NotAuthorized but nil value") } res.NotAuthorized = v default: return false, nil, err } } else { } return err == nil, &res, nil } type tchanKeyValueClient struct { TChanBaseService thriftService string client thrift.TChanClient } func NewTChanKeyValueInheritedClient(thriftService string, client thrift.TChanClient) *tchanKeyValueClient { return &tchanKeyValueClient{ NewTChanBaseServiceInheritedClient(thriftService, client), thriftService, client, } } // NewTChanKeyValueClient creates a client that can be used to make remote calls. 
func NewTChanKeyValueClient(client thrift.TChanClient) TChanKeyValue { return NewTChanKeyValueInheritedClient("KeyValue", client) } func (c *tchanKeyValueClient) Get(ctx thrift.Context, key string) (string, error) { var resp KeyValueGetResult args := KeyValueGetArgs{ Key: key, } success, err := c.client.Call(ctx, c.thriftService, "Get", &args, &resp) if err == nil && !success { switch { case resp.NotFound != nil: err = resp.NotFound case resp.InvalidKey != nil: err = resp.InvalidKey default: err = fmt.Errorf("received no result or unknown exception for Get") } } return resp.GetSuccess(), err } func (c *tchanKeyValueClient) Set(ctx thrift.Context, key string, value string) error { var resp KeyValueSetResult args := KeyValueSetArgs{ Key: key, Value: value, } success, err := c.client.Call(ctx, c.thriftService, "Set", &args, &resp) if err == nil && !success { switch { case resp.InvalidKey != nil: err = resp.InvalidKey default: err = fmt.Errorf("received no result or unknown exception for Set") } } return err } type tchanKeyValueServer struct { thrift.TChanServer handler TChanKeyValue } // NewTChanKeyValueServer wraps a handler for TChanKeyValue so it can be // registered with a thrift.Server. 
func NewTChanKeyValueServer(handler TChanKeyValue) thrift.TChanServer { return &tchanKeyValueServer{ NewTChanBaseServiceServer(handler), handler, } } func (s *tchanKeyValueServer) Service() string { return "KeyValue" } func (s *tchanKeyValueServer) Methods() []string { return []string{ "Get", "Set", "HealthCheck", } } func (s *tchanKeyValueServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "Get": return s.handleGet(ctx, protocol) case "Set": return s.handleSet(ctx, protocol) case "HealthCheck": return s.TChanServer.Handle(ctx, methodName, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanKeyValueServer) handleGet(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req KeyValueGetArgs var res KeyValueGetResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Get(ctx, req.Key) if err != nil { switch v := err.(type) { case *KeyNotFound: if v == nil { return false, nil, fmt.Errorf("Handler for notFound returned non-nil error type *KeyNotFound but nil value") } res.NotFound = v case *InvalidKey: if v == nil { return false, nil, fmt.Errorf("Handler for invalidKey returned non-nil error type *InvalidKey but nil value") } res.InvalidKey = v default: return false, nil, err } } else { res.Success = &r } return err == nil, &res, nil } func (s *tchanKeyValueServer) handleSet(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req KeyValueSetArgs var res KeyValueSetResult if err := req.Read(protocol); err != nil { return false, nil, err } err := s.handler.Set(ctx, req.Key, req.Value) if err != nil { switch v := err.(type) { case *InvalidKey: if v == nil { return false, nil, fmt.Errorf("Handler for invalidKey returned non-nil error type *InvalidKey but nil value") } res.InvalidKey = v default: return false, 
nil, err } } else { } return err == nil, &res, nil } type tchanBaseServiceClient struct { thriftService string client thrift.TChanClient } func NewTChanBaseServiceInheritedClient(thriftService string, client thrift.TChanClient) *tchanBaseServiceClient { return &tchanBaseServiceClient{ thriftService, client, } } // NewTChanBaseServiceClient creates a client that can be used to make remote calls. func NewTChanBaseServiceClient(client thrift.TChanClient) TChanBaseService { return NewTChanBaseServiceInheritedClient("baseService", client) } func (c *tchanBaseServiceClient) HealthCheck(ctx thrift.Context) (string, error) { var resp BaseServiceHealthCheckResult args := BaseServiceHealthCheckArgs{} success, err := c.client.Call(ctx, c.thriftService, "HealthCheck", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for HealthCheck") } } return resp.GetSuccess(), err } type tchanBaseServiceServer struct { handler TChanBaseService } // NewTChanBaseServiceServer wraps a handler for TChanBaseService so it can be // registered with a thrift.Server. 
func NewTChanBaseServiceServer(handler TChanBaseService) thrift.TChanServer { return &tchanBaseServiceServer{ handler, } } func (s *tchanBaseServiceServer) Service() string { return "baseService" } func (s *tchanBaseServiceServer) Methods() []string { return []string{ "HealthCheck", } } func (s *tchanBaseServiceServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "HealthCheck": return s.handleHealthCheck(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanBaseServiceServer) handleHealthCheck(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req BaseServiceHealthCheckArgs var res BaseServiceHealthCheckResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.HealthCheck(ctx) if err != nil { return false, nil, err } else { res.Success = &r } return err == nil, &res, nil } ================================================ FILE: examples/keyvalue/gen-go/keyvalue/ttypes.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package keyvalue import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) 
var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal var GoUnusedProtection__ int // Attributes: // - Key type KeyNotFound struct { Key string `thrift:"key,1" db:"key" json:"key"` } func NewKeyNotFound() *KeyNotFound { return &KeyNotFound{} } func (p *KeyNotFound) GetKey() string { return p.Key } func (p *KeyNotFound) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *KeyNotFound) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Key = v } return nil } func (p *KeyNotFound) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("KeyNotFound"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *KeyNotFound) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("key", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:key: ", p), err) } if err := 
oprot.WriteString(string(p.Key)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.key (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:key: ", p), err) } return err } func (p *KeyNotFound) String() string { if p == nil { return "" } return fmt.Sprintf("KeyNotFound(%+v)", *p) } func (p *KeyNotFound) Error() string { return p.String() } type InvalidKey struct { } func NewInvalidKey() *InvalidKey { return &InvalidKey{} } func (p *InvalidKey) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *InvalidKey) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("InvalidKey"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *InvalidKey) String() string { if p == nil { return "" } return fmt.Sprintf("InvalidKey(%+v)", *p) } func (p *InvalidKey) Error() string { return p.String() } type NotAuthorized struct { } func NewNotAuthorized() *NotAuthorized { return &NotAuthorized{} } func (p *NotAuthorized) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return 
thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *NotAuthorized) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("NotAuthorized"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *NotAuthorized) String() string { if p == nil { return "" } return fmt.Sprintf("NotAuthorized(%+v)", *p) } func (p *NotAuthorized) Error() string { return p.String() } ================================================ FILE: examples/keyvalue/keyvalue.thrift ================================================ service baseService { string HealthCheck() } exception KeyNotFound { 1: string key } exception InvalidKey {} service KeyValue extends baseService { // If the key does not start with a letter, InvalidKey is returned. // If the key does not exist, KeyNotFound is returned. string Get(1: string key) throws ( 1: KeyNotFound notFound 2: InvalidKey invalidKey) // Set returns InvalidKey is an invalid key is sent. void Set(1: string key, 2: string value) throws ( 1: InvalidKey invalidKey ) } // Returned when the user is not authorized for the Admin service. 
exception NotAuthorized {} service Admin extends baseService { void clearAll() throws (1: NotAuthorized notAuthorized) } ================================================ FILE: examples/keyvalue/server/server.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "fmt" "log" "os" "sync" "unicode" "unicode/utf8" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/examples/keyvalue/gen-go/keyvalue" "github.com/uber/tchannel-go/hyperbahn" "github.com/uber/tchannel-go/pprof" "github.com/uber/tchannel-go/thrift" ) func main() { // Create a TChannel and register the Thrift handlers. ch, err := tchannel.NewChannel("keyvalue", nil) if err != nil { log.Fatalf("Failed to create tchannel: %v", err) } // Register both the KeyValue and Admin services. // We can register multiple Thrift services on a single Hyperbahn service. 
h := newKVHandler() server := thrift.NewServer(ch) server.Register(keyvalue.NewTChanKeyValueServer(h)) server.Register(keyvalue.NewTChanAdminServer(h)) pprof.Register(ch) // Listen for connections on the external interface so we can receive connections. ip, err := tchannel.ListenIP() if err != nil { log.Fatalf("Failed to find IP to Listen on: %v", err) } // We use port 0 which asks the OS to assign any available port. // Static port allocations are not necessary for services on Hyperbahn. ch.ListenAndServe(fmt.Sprintf("%v:%v", ip, 0)) // Advertising registers this service instance with Hyperbahn so // that Hyperbahn can route requests for "keyvalue" to us. config := hyperbahn.Configuration{InitialNodes: os.Args[1:]} if len(config.InitialNodes) > 0 { client, err := hyperbahn.NewClient(ch, config, nil) if err != nil { log.Fatalf("hyperbahn.NewClient failed: %v", err) } if err := client.Advertise(); err != nil { log.Fatalf("Hyperbahn advertise failed: %v", err) } } // The service is now started up, run it till we receive a ctrl-c. log.Printf("KeyValue service has started on %v", ch.PeerInfo().HostPort) select {} } type kvHandler struct { sync.RWMutex vals map[string]string } // NewKVHandler returns a new handler for the KeyValue service. func newKVHandler() *kvHandler { return &kvHandler{vals: make(map[string]string)} } // Get returns the value stored for the given key. func (h *kvHandler) Get(ctx thrift.Context, key string) (string, error) { if err := isValidKey(key); err != nil { return "", err } h.RLock() defer h.RUnlock() if val, ok := h.vals[key]; ok { return val, nil } return "", &keyvalue.KeyNotFound{Key: key} } // Set sets the value for a given key. func (h *kvHandler) Set(ctx thrift.Context, key, value string) error { if err := isValidKey(key); err != nil { return err } h.Lock() defer h.Unlock() h.vals[key] = value // Example of how to use response headers. Normally, these values should be passed via result structs. 
ctx.SetResponseHeaders(map[string]string{"count": fmt.Sprint(len(h.vals))}) return nil } // HealthCheck return the health status of this process. func (h *kvHandler) HealthCheck(ctx thrift.Context) (string, error) { return "OK", nil } // ClearAll clears all the keys. func (h *kvHandler) ClearAll(ctx thrift.Context) error { if !isAdmin(ctx) { return &keyvalue.NotAuthorized{} } h.Lock() defer h.Unlock() h.vals = make(map[string]string) return nil } func isValidKey(key string) error { r, _ := utf8.DecodeRuneInString(key) if !unicode.IsLetter(r) { return &keyvalue.InvalidKey{} } return nil } func isAdmin(ctx thrift.Context) bool { return ctx.Headers()["user"] == "admin" } ================================================ FILE: examples/ping/README.md ================================================ # Ping-Pong ```bash ./build/examples/ping/pong ``` This example creates a client and server channel. The server channel registers a `PingService` with a `ping` method, which takes request `Headers` and a `Ping` body and returns the same `Headers` along with a `Pong` body. The client sends a ping request to the server. Note that every instance is bidirectional, so the same channel can be used for both sending and receiving requests to peers. New connections are initiated on demand. ================================================ FILE: examples/ping/main.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "fmt" "time" "golang.org/x/net/context" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/json" ) var log = tchannel.SimpleLogger // Ping is the ping request type. type Ping struct { Message string `json:"message"` } // Pong is the ping response type. type Pong Ping func pingHandler(ctx json.Context, ping *Ping) (*Pong, error) { return &Pong{ Message: fmt.Sprintf("ping %v", ping), }, nil } func pingOtherHandler(ctx json.Context, ping *Ping) (*Pong, error) { return &Pong{ Message: fmt.Sprintf("pingOther %v", ping), }, nil } func onError(ctx context.Context, err error) { log.WithFields(tchannel.ErrField(err)).Fatal("onError handler triggered.") } func listenAndHandle(s *tchannel.Channel, hostPort string) { log.Infof("Service %s", hostPort) // If no error is returned, the listen was successful. Serving happens in the background. 
if err := s.ListenAndServe(hostPort); err != nil { log.WithFields( tchannel.LogField{Key: "hostPort", Value: hostPort}, tchannel.ErrField(err), ).Fatal("Couldn't listen.") } } func main() { // Create a new TChannel for handling requests ch, err := tchannel.NewChannel("PingService", &tchannel.ChannelOptions{Logger: tchannel.SimpleLogger}) if err != nil { log.WithFields(tchannel.ErrField(err)).Fatal("Couldn't create new channel.") } // Register a handler for the ping message on the PingService json.Register(ch, json.Handlers{ "ping": pingHandler, }, onError) // Listen for incoming requests listenAndHandle(ch, "127.0.0.1:10500") // Create a new TChannel for sending requests. client, err := tchannel.NewChannel("ping-client", nil) if err != nil { log.WithFields(tchannel.ErrField(err)).Fatal("Couldn't create new client channel.") } // Make a call to ourselves, with a timeout of 10s ctx, cancel := json.NewContext(time.Second * 10) defer cancel() peer := client.Peers().Add(ch.PeerInfo().HostPort) var pong Pong if err := json.CallPeer(ctx, peer, "PingService", "ping", &Ping{"Hello World"}, &pong); err != nil { log.WithFields(tchannel.ErrField(err)).Fatal("json.Call failed.") } log.Infof("Received pong: %s", pong.Message) // Create a new subchannel for the top-level channel subCh := ch.GetSubChannel("PingServiceOther") // Register a handler on the subchannel json.Register(subCh, json.Handlers{ "pingOther": pingOtherHandler, }, onError) // Try to send a message to the Service:Method pair for the subchannel if err := json.CallPeer(ctx, peer, "PingServiceOther", "pingOther", &Ping{"Hello Other World"}, &pong); err != nil { log.WithFields(tchannel.ErrField(err)).Fatal("json.Call failed.") } log.Infof("Received pong: %s", pong.Message) } ================================================ FILE: examples/test_server/server.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package main import ( "flag" "fmt" "log" "golang.org/x/net/context" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" ) var ( flagHost = flag.String("host", "localhost", "The hostname to serve on") flagPort = flag.Int("port", 0, "The port to listen on") ) type rawHandler struct{} func (rawHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ Arg2: args.Arg2, Arg3: args.Arg3, }, nil } func (rawHandler) OnError(ctx context.Context, err error) { log.Fatalf("OnError: %v", err) } func main() { flag.Parse() ch, err := tchannel.NewChannel("test_as_raw", nil) if err != nil { log.Fatalf("NewChannel failed: %v", err) } handler := raw.Wrap(rawHandler{}) ch.Register(handler, "echo") ch.Register(handler, "streaming_echo") hostPort := fmt.Sprintf("%s:%v", *flagHost, *flagPort) if err := ch.ListenAndServe(hostPort); err != nil { log.Fatalf("ListenAndServe failed: %v", err) } fmt.Println("listening on", ch.PeerInfo().HostPort) select {} } ================================================ FILE: examples/thrift/example.thrift ================================================ struct HealthCheckRes { 1: bool healthy, 2: string msg, } service Base { void BaseCall() } service First extends Base { string Echo(1:string msg) HealthCheckRes Healthcheck() void AppError() } service Second { void Test() } ================================================ FILE: examples/thrift/gen-go/example/base.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package example import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) 
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

// Base is the Go interface for the Thrift service Base, which declares the
// single method BaseCall.
type Base interface {
	BaseCall() (err error)
}

// BaseClient is a client for the Base service. It holds the transport, an
// optional protocol factory, the input and output protocols, and SeqId, the
// sequence ID that is incremented before each request and checked against
// the sequence ID of the response.
type BaseClient struct {
	Transport       thrift.TTransport
	ProtocolFactory thrift.TProtocolFactory
	InputProtocol   thrift.TProtocol
	OutputProtocol  thrift.TProtocol
	SeqId           int32
}

// NewBaseClientFactory returns a BaseClient for transport t whose input and
// output protocols are each obtained from f.GetProtocol(t). SeqId starts at 0.
func NewBaseClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *BaseClient {
	return &BaseClient{Transport: t,
		ProtocolFactory: f,
		InputProtocol:   f.GetProtocol(t),
		OutputProtocol:  f.GetProtocol(t),
		SeqId:           0,
	}
}

// NewBaseClientProtocol returns a BaseClient that uses the given input and
// output protocols directly and has no protocol factory. SeqId starts at 0.
//
// NOTE(review): the send/recv methods fall back to ProtocolFactory when a
// protocol is nil, and ProtocolFactory is nil here, so iprot and oprot
// presumably must be non-nil — confirm against callers.
func NewBaseClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *BaseClient {
	return &BaseClient{Transport: t,
		ProtocolFactory: nil,
		InputProtocol:   iprot,
		OutputProtocol:  oprot,
		SeqId:           0,
	}
}

// BaseCall sends a BaseCall request and then reads its response. It returns
// the send error if sending fails, otherwise the result of reading the reply.
func (p *BaseClient) BaseCall() (err error) {
	if err = p.sendBaseCall(); err != nil {
		return
	}
	return p.recvBaseCall()
}

// sendBaseCall writes one BaseCall request to the output protocol: message
// begin (type thrift.CALL, with the incremented SeqId), the empty argument
// struct, message end, and a flush. It stops at and returns the first error.
func (p *BaseClient) sendBaseCall() (err error) {
	oprot := p.OutputProtocol
	// Create the output protocol from the factory on first use and cache it.
	if oprot == nil {
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("BaseCall", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := BaseBaseCallArgs{}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvBaseCall reads one response from the input protocol and validates it.
//
// It returns an error when the message header cannot be read, when the method
// name is not "BaseCall", when the sequence ID differs from p.SeqId, when the
// message is an exception (the decoded exception is returned as the error),
// when the message type is not thrift.REPLY, or when reading the result
// struct or the message end fails. It returns nil on a well-formed reply.
func (p *BaseClient) recvBaseCall() (err error) {
	iprot := p.InputProtocol
	// Create the input protocol from the factory on first use and cache it.
	if iprot == nil {
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	// Only method, mTypeId and seqId are new here; err is the named result,
	// so the naked returns below report whatever was last assigned to it.
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "BaseCall" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "BaseCall failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "BaseCall failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		// Decode the exception sent by the peer; error0 is only the decoder,
		// error1 receives the decoded exception that is returned to the caller.
		error0 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error1 error
		error1, err = error0.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		err = error1
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "BaseCall failed: invalid message type")
		return
	}
	result := BaseBaseCallResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	return
}

// BaseProcessor dispatches incoming Base service calls. processorMap maps a
// method name to the processor function for it; handler is the Base
// implementation the processor functions call.
type BaseProcessor struct {
	processorMap map[string]thrift.TProcessorFunction
	handler      Base
}

// AddToProcessorMap registers processor under key, replacing any processor
// already stored for that key.
func (p *BaseProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
	p.processorMap[key] = processor
}

// GetProcessorFunction looks up the processor registered under key. ok is
// false when no processor is registered for key.
func (p *BaseProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
	processor, ok = p.processorMap[key]
	return processor, ok
}

// ProcessorMap returns the processor's internal method map itself, not a copy.
func (p *BaseProcessor) ProcessorMap() map[string]thrift.TProcessorFunction {
	return p.processorMap
}

// NewBaseProcessor returns a BaseProcessor for handler with the "BaseCall"
// method registered.
func NewBaseProcessor(handler Base) *BaseProcessor {

	self2 := &BaseProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)}
	self2.processorMap["BaseCall"] = &baseProcessorBaseCall{handler: handler}
	return self2
}

// Process reads one message header from iprot and dispatches the call to the
// processor registered for the method name.
//
// If the header cannot be read it returns false and that error. If no
// processor is registered for the name, it skips the argument struct, writes
// an UNKNOWN_METHOD exception message to oprot, and returns false together
// with that exception; the return values of those skip and write calls are
// not checked.
func (p *BaseProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	// name and seqId are new; err is the named result.
	name, _, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return false, err
	}
	if processor, ok := p.GetProcessorFunction(name); ok {
		return processor.Process(seqId, iprot, oprot)
	}
	iprot.Skip(thrift.STRUCT)
	iprot.ReadMessageEnd()
	x3 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)
	oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)
	x3.Write(oprot)
	oprot.WriteMessageEnd()
	oprot.Flush()
	return false, x3

}

// baseProcessorBaseCall is the processor function for the BaseCall method.
type baseProcessorBaseCall struct {
	handler Base
}

// Process handles one BaseCall request whose header has already been read.
//
// It reads the argument struct, calls handler.BaseCall, and writes either a
// reply or an exception message to oprot:
//   - if the arguments cannot be read, it writes a PROTOCOL_ERROR exception
//     and returns false with the read error;
//   - if the handler returns an error, it writes an INTERNAL_ERROR exception
//     and returns true with the handler's error;
//   - otherwise it writes the reply; if any write step fails it returns
//     false with the first such error, else true with a nil error.
func (p *baseProcessorBaseCall) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := BaseBaseCallArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("BaseCall", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := BaseBaseCallResult{}
	var err2 error
	if err2 = p.handler.BaseCall(); err2 != nil {
		x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing BaseCall: "+err2.Error())
		oprot.WriteMessageBegin("BaseCall", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return true, err2
	}
	// Every write step below runs even after a failure; the `err == nil`
	// guards make err keep the first failure rather than the last.
	if err2 = oprot.WriteMessageBegin("BaseCall", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		// Naked return: success was never assigned, so it is false here.
		return
	}
	return true, err
}

// HELPER FUNCTIONS AND STRUCTURES

// BaseBaseCallArgs is the argument struct for BaseCall; it has no fields.
type BaseBaseCallArgs struct {
}

// NewBaseBaseCallArgs returns a new, empty BaseBaseCallArgs.
func NewBaseBaseCallArgs() *BaseBaseCallArgs {
	return &BaseBaseCallArgs{}
}

// Read consumes one struct from iprot. The struct declares no fields, so
// every field encountered is skipped until the STOP marker. Errors from the
// struct begin, field begin and struct end steps are returned with a
// descriptive prefix; errors from skipping a field or reading a field end are
// returned as they are.
func (p *BaseBaseCallArgs) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		if err := iprot.Skip(fieldTypeId); err != nil {
			return err
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// Write emits the struct to oprot as struct begin ("BaseCall_args"), field
// stop and struct end, returning the first error with a descriptive prefix.
func (p *BaseBaseCallArgs) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("BaseCall_args"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// String returns the struct formatted with %+v inside "BaseBaseCallArgs(...)".
// A nil receiver yields the empty string.
func (p *BaseBaseCallArgs) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("BaseBaseCallArgs(%+v)", *p)
}

// BaseBaseCallResult is the result struct for BaseCall; it has no fields.
type BaseBaseCallResult struct {
}

// NewBaseBaseCallResult returns a new, empty BaseBaseCallResult.
func NewBaseBaseCallResult() *BaseBaseCallResult {
	return &BaseBaseCallResult{}
}

// Read consumes one struct from iprot. The struct declares no fields, so
// every field encountered is skipped until the STOP marker. Errors from the
// struct begin, field begin and struct end steps are returned with a
// descriptive prefix; errors from skipping a field or reading a field end are
// returned as they are.
func (p *BaseBaseCallResult) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		if err := iprot.Skip(fieldTypeId); err != nil {
			return err
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// Write emits the struct to oprot as struct begin ("BaseCall_result"), field
// stop and struct end, returning the first error with a descriptive prefix.
func (p *BaseBaseCallResult) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("BaseCall_result"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// String returns the struct formatted with %+v inside
// "BaseBaseCallResult(...)". A nil receiver yields the empty string.
func (p *BaseBaseCallResult) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("BaseBaseCallResult(%+v)", *p)
}

================================================
FILE: examples/thrift/gen-go/example/constants.go
================================================
// Autogenerated by Thrift Compiler (1.0.0-dev)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING

package example

import (
	"bytes"
	"fmt"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
)

// (needed to ensure safety because of
naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal func init() { } ================================================ FILE: examples/thrift/gen-go/example/first.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package example import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal type First interface { Base // Parameters: // - Msg Echo(msg string) (r string, err error) Healthcheck() (r *HealthCheckRes, err error) AppError() (err error) } type FirstClient struct { *BaseClient } func NewFirstClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *FirstClient { return &FirstClient{BaseClient: NewBaseClientFactory(t, f)} } func NewFirstClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *FirstClient { return &FirstClient{BaseClient: NewBaseClientProtocol(t, iprot, oprot)} } // Parameters: // - Msg func (p *FirstClient) Echo(msg string) (r string, err error) { if err = p.sendEcho(msg); err != nil { return } return p.recvEcho() } func (p *FirstClient) sendEcho(msg string) (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("Echo", thrift.CALL, p.SeqId); err != nil { return } args := FirstEchoArgs{ Msg: msg, } if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *FirstClient) recvEcho() (value string, err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err 
!= nil { return } if method != "Echo" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Echo failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Echo failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error4 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error5 error error5, err = error4.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error5 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Echo failed: invalid message type") return } result := FirstEchoResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } value = result.GetSuccess() return } func (p *FirstClient) Healthcheck() (r *HealthCheckRes, err error) { if err = p.sendHealthcheck(); err != nil { return } return p.recvHealthcheck() } func (p *FirstClient) sendHealthcheck() (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("Healthcheck", thrift.CALL, p.SeqId); err != nil { return } args := FirstHealthcheckArgs{} if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *FirstClient) recvHealthcheck() (value *HealthCheckRes, err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "Healthcheck" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Healthcheck failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Healthcheck 
failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error6 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error7 error error7, err = error6.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error7 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Healthcheck failed: invalid message type") return } result := FirstHealthcheckResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } value = result.GetSuccess() return } func (p *FirstClient) AppError() (err error) { if err = p.sendAppError(); err != nil { return } return p.recvAppError() } func (p *FirstClient) sendAppError() (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("AppError", thrift.CALL, p.SeqId); err != nil { return } args := FirstAppErrorArgs{} if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *FirstClient) recvAppError() (err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "AppError" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "AppError failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "AppError failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error8 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error9 error error9, err = error8.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil 
{ return } err = error9 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "AppError failed: invalid message type") return } result := FirstAppErrorResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } return } type FirstProcessor struct { *BaseProcessor } func NewFirstProcessor(handler First) *FirstProcessor { self10 := &FirstProcessor{NewBaseProcessor(handler)} self10.AddToProcessorMap("Echo", &firstProcessorEcho{handler: handler}) self10.AddToProcessorMap("Healthcheck", &firstProcessorHealthcheck{handler: handler}) self10.AddToProcessorMap("AppError", &firstProcessorAppError{handler: handler}) return self10 } type firstProcessorEcho struct { handler First } func (p *firstProcessorEcho) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := FirstEchoArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("Echo", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := FirstEchoResult{} var retval string var err2 error if retval, err2 = p.handler.Echo(args.Msg); err2 != nil { x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Echo: "+err2.Error()) oprot.WriteMessageBegin("Echo", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } else { result.Success = &retval } if err2 = oprot.WriteMessageBegin("Echo", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } type firstProcessorHealthcheck 
struct { handler First } func (p *firstProcessorHealthcheck) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := FirstHealthcheckArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("Healthcheck", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := FirstHealthcheckResult{} var retval *HealthCheckRes var err2 error if retval, err2 = p.handler.Healthcheck(); err2 != nil { x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Healthcheck: "+err2.Error()) oprot.WriteMessageBegin("Healthcheck", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } else { result.Success = retval } if err2 = oprot.WriteMessageBegin("Healthcheck", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } type firstProcessorAppError struct { handler First } func (p *firstProcessorAppError) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := FirstAppErrorArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("AppError", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := FirstAppErrorResult{} var err2 error if err2 = p.handler.AppError(); err2 != nil { x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing AppError: "+err2.Error()) oprot.WriteMessageBegin("AppError", 
thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } if err2 = oprot.WriteMessageBegin("AppError", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } // HELPER FUNCTIONS AND STRUCTURES // Attributes: // - Msg type FirstEchoArgs struct { Msg string `thrift:"msg,1" db:"msg" json:"msg"` } func NewFirstEchoArgs() *FirstEchoArgs { return &FirstEchoArgs{} } func (p *FirstEchoArgs) GetMsg() string { return p.Msg } func (p *FirstEchoArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *FirstEchoArgs) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Msg = v } return nil } func (p *FirstEchoArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Echo_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write 
field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *FirstEchoArgs) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("msg", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:msg: ", p), err) } if err := oprot.WriteString(string(p.Msg)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.msg (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:msg: ", p), err) } return err } func (p *FirstEchoArgs) String() string { if p == nil { return "" } return fmt.Sprintf("FirstEchoArgs(%+v)", *p) } // Attributes: // - Success type FirstEchoResult struct { Success *string `thrift:"success,0" db:"success" json:"success,omitempty"` } func NewFirstEchoResult() *FirstEchoResult { return &FirstEchoResult{} } var FirstEchoResult_Success_DEFAULT string func (p *FirstEchoResult) GetSuccess() string { if !p.IsSetSuccess() { return FirstEchoResult_Success_DEFAULT } return *p.Success } func (p *FirstEchoResult) IsSetSuccess() bool { return p.Success != nil } func (p *FirstEchoResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 0: if err := p.ReadField0(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p 
*FirstEchoResult) ReadField0(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { p.Success = &v } return nil } func (p *FirstEchoResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Echo_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField0(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *FirstEchoResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err := oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err) } if err := oprot.WriteString(string(*p.Success)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.success (0) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err) } } return err } func (p *FirstEchoResult) String() string { if p == nil { return "" } return fmt.Sprintf("FirstEchoResult(%+v)", *p) } type FirstHealthcheckArgs struct { } func NewFirstHealthcheckArgs() *FirstHealthcheckArgs { return &FirstHealthcheckArgs{} } func (p *FirstHealthcheckArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); 
err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *FirstHealthcheckArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Healthcheck_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *FirstHealthcheckArgs) String() string { if p == nil { return "" } return fmt.Sprintf("FirstHealthcheckArgs(%+v)", *p) } // Attributes: // - Success type FirstHealthcheckResult struct { Success *HealthCheckRes `thrift:"success,0" db:"success" json:"success,omitempty"` } func NewFirstHealthcheckResult() *FirstHealthcheckResult { return &FirstHealthcheckResult{} } var FirstHealthcheckResult_Success_DEFAULT *HealthCheckRes func (p *FirstHealthcheckResult) GetSuccess() *HealthCheckRes { if !p.IsSetSuccess() { return FirstHealthcheckResult_Success_DEFAULT } return p.Success } func (p *FirstHealthcheckResult) IsSetSuccess() bool { return p.Success != nil } func (p *FirstHealthcheckResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 0: if err := p.ReadField0(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end 
error: ", p), err) } return nil } func (p *FirstHealthcheckResult) ReadField0(iprot thrift.TProtocol) error { p.Success = &HealthCheckRes{} if err := p.Success.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Success), err) } return nil } func (p *FirstHealthcheckResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Healthcheck_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField0(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *FirstHealthcheckResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err := oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err) } if err := p.Success.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Success), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err) } } return err } func (p *FirstHealthcheckResult) String() string { if p == nil { return "" } return fmt.Sprintf("FirstHealthcheckResult(%+v)", *p) } type FirstAppErrorArgs struct { } func NewFirstAppErrorArgs() *FirstAppErrorArgs { return &FirstAppErrorArgs{} } func (p *FirstAppErrorArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err 
:= iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *FirstAppErrorArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("AppError_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *FirstAppErrorArgs) String() string { if p == nil { return "" } return fmt.Sprintf("FirstAppErrorArgs(%+v)", *p) } type FirstAppErrorResult struct { } func NewFirstAppErrorResult() *FirstAppErrorResult { return &FirstAppErrorResult{} } func (p *FirstAppErrorResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *FirstAppErrorResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("AppError_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write 
struct stop error: ", err) } return nil } func (p *FirstAppErrorResult) String() string { if p == nil { return "" } return fmt.Sprintf("FirstAppErrorResult(%+v)", *p) } ================================================ FILE: examples/thrift/gen-go/example/second.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package example import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal type Second interface { Test() (err error) } type SecondClient struct { Transport thrift.TTransport ProtocolFactory thrift.TProtocolFactory InputProtocol thrift.TProtocol OutputProtocol thrift.TProtocol SeqId int32 } func NewSecondClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *SecondClient { return &SecondClient{Transport: t, ProtocolFactory: f, InputProtocol: f.GetProtocol(t), OutputProtocol: f.GetProtocol(t), SeqId: 0, } } func NewSecondClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *SecondClient { return &SecondClient{Transport: t, ProtocolFactory: nil, InputProtocol: iprot, OutputProtocol: oprot, SeqId: 0, } } func (p *SecondClient) Test() (err error) { if err = p.sendTest(); err != nil { return } return p.recvTest() } func (p *SecondClient) sendTest() (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("Test", thrift.CALL, p.SeqId); err != nil { return } args := SecondTestArgs{} if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *SecondClient) recvTest() (err error) { iprot := p.InputProtocol if iprot == nil { iprot = 
p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "Test" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Test failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Test failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error12 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error13 error error13, err = error12.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error13 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Test failed: invalid message type") return } result := SecondTestResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } return } type SecondProcessor struct { processorMap map[string]thrift.TProcessorFunction handler Second } func (p *SecondProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) { p.processorMap[key] = processor } func (p *SecondProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) { processor, ok = p.processorMap[key] return processor, ok } func (p *SecondProcessor) ProcessorMap() map[string]thrift.TProcessorFunction { return p.processorMap } func NewSecondProcessor(handler Second) *SecondProcessor { self14 := &SecondProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)} self14.processorMap["Test"] = &secondProcessorTest{handler: handler} return self14 } func (p *SecondProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { name, _, seqId, err := iprot.ReadMessageBegin() if err != nil { return false, err } if processor, ok := p.GetProcessorFunction(name); ok { return 
processor.Process(seqId, iprot, oprot) } iprot.Skip(thrift.STRUCT) iprot.ReadMessageEnd() x15 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) x15.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, x15 } type secondProcessorTest struct { handler Second } func (p *secondProcessorTest) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := SecondTestArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("Test", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := SecondTestResult{} var err2 error if err2 = p.handler.Test(); err2 != nil { x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Test: "+err2.Error()) oprot.WriteMessageBegin("Test", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } if err2 = oprot.WriteMessageBegin("Test", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } // HELPER FUNCTIONS AND STRUCTURES type SecondTestArgs struct { } func NewSecondTestArgs() *SecondTestArgs { return &SecondTestArgs{} } func (p *SecondTestArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err 
:= iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *SecondTestArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Test_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *SecondTestArgs) String() string { if p == nil { return "" } return fmt.Sprintf("SecondTestArgs(%+v)", *p) } type SecondTestResult struct { } func NewSecondTestResult() *SecondTestResult { return &SecondTestResult{} } func (p *SecondTestResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *SecondTestResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Test_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } 
return nil } func (p *SecondTestResult) String() string { if p == nil { return "" } return fmt.Sprintf("SecondTestResult(%+v)", *p) } ================================================ FILE: examples/thrift/gen-go/example/tchan-example.go ================================================ // @generated Code generated by thrift-gen. Do not modify. // Package example is generated code used to make or handle TChannel calls using Thrift. package example import ( "fmt" athrift "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" "github.com/uber/tchannel-go/thrift" ) // Interfaces for the service and client for the services defined in the IDL. // TChanBase is the interface that defines the server handler and client interface. type TChanBase interface { BaseCall(ctx thrift.Context) error } // TChanFirst is the interface that defines the server handler and client interface. type TChanFirst interface { TChanBase AppError(ctx thrift.Context) error Echo(ctx thrift.Context, msg string) (string, error) Healthcheck(ctx thrift.Context) (*HealthCheckRes, error) } // TChanSecond is the interface that defines the server handler and client interface. type TChanSecond interface { Test(ctx thrift.Context) error } // Implementation of a client and service handler. type tchanBaseClient struct { thriftService string client thrift.TChanClient } func NewTChanBaseInheritedClient(thriftService string, client thrift.TChanClient) *tchanBaseClient { return &tchanBaseClient{ thriftService, client, } } // NewTChanBaseClient creates a client that can be used to make remote calls. 
func NewTChanBaseClient(client thrift.TChanClient) TChanBase { return NewTChanBaseInheritedClient("Base", client) } func (c *tchanBaseClient) BaseCall(ctx thrift.Context) error { var resp BaseBaseCallResult args := BaseBaseCallArgs{} success, err := c.client.Call(ctx, c.thriftService, "BaseCall", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for BaseCall") } } return err } type tchanBaseServer struct { handler TChanBase } // NewTChanBaseServer wraps a handler for TChanBase so it can be // registered with a thrift.Server. func NewTChanBaseServer(handler TChanBase) thrift.TChanServer { return &tchanBaseServer{ handler, } } func (s *tchanBaseServer) Service() string { return "Base" } func (s *tchanBaseServer) Methods() []string { return []string{ "BaseCall", } } func (s *tchanBaseServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "BaseCall": return s.handleBaseCall(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanBaseServer) handleBaseCall(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req BaseBaseCallArgs var res BaseBaseCallResult if err := req.Read(protocol); err != nil { return false, nil, err } err := s.handler.BaseCall(ctx) if err != nil { return false, nil, err } else { } return err == nil, &res, nil } type tchanFirstClient struct { TChanBase thriftService string client thrift.TChanClient } func NewTChanFirstInheritedClient(thriftService string, client thrift.TChanClient) *tchanFirstClient { return &tchanFirstClient{ NewTChanBaseInheritedClient(thriftService, client), thriftService, client, } } // NewTChanFirstClient creates a client that can be used to make remote calls. 
func NewTChanFirstClient(client thrift.TChanClient) TChanFirst { return NewTChanFirstInheritedClient("First", client) } func (c *tchanFirstClient) AppError(ctx thrift.Context) error { var resp FirstAppErrorResult args := FirstAppErrorArgs{} success, err := c.client.Call(ctx, c.thriftService, "AppError", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for AppError") } } return err } func (c *tchanFirstClient) Echo(ctx thrift.Context, msg string) (string, error) { var resp FirstEchoResult args := FirstEchoArgs{ Msg: msg, } success, err := c.client.Call(ctx, c.thriftService, "Echo", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for Echo") } } return resp.GetSuccess(), err } func (c *tchanFirstClient) Healthcheck(ctx thrift.Context) (*HealthCheckRes, error) { var resp FirstHealthcheckResult args := FirstHealthcheckArgs{} success, err := c.client.Call(ctx, c.thriftService, "Healthcheck", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for Healthcheck") } } return resp.GetSuccess(), err } type tchanFirstServer struct { thrift.TChanServer handler TChanFirst } // NewTChanFirstServer wraps a handler for TChanFirst so it can be // registered with a thrift.Server. 
func NewTChanFirstServer(handler TChanFirst) thrift.TChanServer { return &tchanFirstServer{ NewTChanBaseServer(handler), handler, } } func (s *tchanFirstServer) Service() string { return "First" } func (s *tchanFirstServer) Methods() []string { return []string{ "AppError", "Echo", "Healthcheck", "BaseCall", } } func (s *tchanFirstServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "AppError": return s.handleAppError(ctx, protocol) case "Echo": return s.handleEcho(ctx, protocol) case "Healthcheck": return s.handleHealthcheck(ctx, protocol) case "BaseCall": return s.TChanServer.Handle(ctx, methodName, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanFirstServer) handleAppError(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req FirstAppErrorArgs var res FirstAppErrorResult if err := req.Read(protocol); err != nil { return false, nil, err } err := s.handler.AppError(ctx) if err != nil { return false, nil, err } else { } return err == nil, &res, nil } func (s *tchanFirstServer) handleEcho(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req FirstEchoArgs var res FirstEchoResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Echo(ctx, req.Msg) if err != nil { return false, nil, err } else { res.Success = &r } return err == nil, &res, nil } func (s *tchanFirstServer) handleHealthcheck(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req FirstHealthcheckArgs var res FirstHealthcheckResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Healthcheck(ctx) if err != nil { return false, nil, err } else { res.Success = r } return err == nil, &res, nil } type tchanSecondClient struct { thriftService string client 
thrift.TChanClient } func NewTChanSecondInheritedClient(thriftService string, client thrift.TChanClient) *tchanSecondClient { return &tchanSecondClient{ thriftService, client, } } // NewTChanSecondClient creates a client that can be used to make remote calls. func NewTChanSecondClient(client thrift.TChanClient) TChanSecond { return NewTChanSecondInheritedClient("Second", client) } func (c *tchanSecondClient) Test(ctx thrift.Context) error { var resp SecondTestResult args := SecondTestArgs{} success, err := c.client.Call(ctx, c.thriftService, "Test", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for Test") } } return err } type tchanSecondServer struct { handler TChanSecond } // NewTChanSecondServer wraps a handler for TChanSecond so it can be // registered with a thrift.Server. func NewTChanSecondServer(handler TChanSecond) thrift.TChanServer { return &tchanSecondServer{ handler, } } func (s *tchanSecondServer) Service() string { return "Second" } func (s *tchanSecondServer) Methods() []string { return []string{ "Test", } } func (s *tchanSecondServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "Test": return s.handleTest(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanSecondServer) handleTest(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req SecondTestArgs var res SecondTestResult if err := req.Read(protocol); err != nil { return false, nil, err } err := s.handler.Test(ctx) if err != nil { return false, nil, err } else { } return err == nil, &res, nil } ================================================ FILE: examples/thrift/gen-go/example/ttypes.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT 
YOU KNOW WHAT YOU ARE DOING package example import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal var GoUnusedProtection__ int // Attributes: // - Healthy // - Msg type HealthCheckRes struct { Healthy bool `thrift:"healthy,1" db:"healthy" json:"healthy"` Msg string `thrift:"msg,2" db:"msg" json:"msg"` } func NewHealthCheckRes() *HealthCheckRes { return &HealthCheckRes{} } func (p *HealthCheckRes) GetHealthy() bool { return p.Healthy } func (p *HealthCheckRes) GetMsg() string { return p.Msg } func (p *HealthCheckRes) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } case 2: if err := p.ReadField2(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *HealthCheckRes) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadBool(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Healthy = v } return nil } func (p *HealthCheckRes) ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.Msg = v } return nil } func (p *HealthCheckRes) Write(oprot thrift.TProtocol) error { if err := 
oprot.WriteStructBegin("HealthCheckRes"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *HealthCheckRes) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("healthy", thrift.BOOL, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:healthy: ", p), err) } if err := oprot.WriteBool(bool(p.Healthy)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.healthy (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:healthy: ", p), err) } return err } func (p *HealthCheckRes) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("msg", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:msg: ", p), err) } if err := oprot.WriteString(string(p.Msg)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.msg (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:msg: ", p), err) } return err } func (p *HealthCheckRes) String() string { if p == nil { return "" } return fmt.Sprintf("HealthCheckRes(%+v)", *p) } ================================================ FILE: examples/thrift/main.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package main

import (
	"bufio"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"runtime"
	"strings"
	"time"

	tchannel "github.com/uber/tchannel-go"
	gen "github.com/uber/tchannel-go/examples/thrift/gen-go/example"
	"github.com/uber/tchannel-go/thrift"
)

// main starts the example server, launches two clients that call it in the
// background, and exits after ten seconds. While running, typing "s" on
// stdin dumps all goroutine stacks.
func main() {
	var (
		listener net.Listener
		err      error
	)

	if listener, err = setupServer(); err != nil {
		log.Fatalf("setupServer failed: %v", err)
	}

	if err := runClient1("server", listener.Addr()); err != nil {
		log.Fatalf("runClient1 failed: %v", err)
	}

	if err := runClient2("server", listener.Addr()); err != nil {
		log.Fatalf("runClient2 failed: %v", err)
	}

	go listenConsole()

	// Run for 10 seconds, then stop
	time.Sleep(time.Second * 10)
}

// setupServer creates the "server" channel, registers the First and Second
// Thrift services on it and starts serving on an ephemeral TCP port.
// It returns the listener so callers can discover the bound address.
func setupServer() (net.Listener, error) {
	tchan, err := tchannel.NewChannel("server", optsFor("server"))
	if err != nil {
		return nil, err
	}

	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		// Do not leave a half-initialised channel behind.
		tchan.Close()
		return nil, err
	}

	server := thrift.NewServer(tchan)
	server.Register(gen.NewTChanFirstServer(&firstHandler{}))
	server.Register(gen.NewTChanSecondServer(&secondHandler{}))

	// Serve will set the local peer info, and start accepting sockets in a separate goroutine.
	if err := tchan.Serve(listener); err != nil {
		// Serve did not take over the listener, so release the port here.
		listener.Close()
		return nil, err
	}
	return listener, nil
}

// runClient1 creates the "client1" channel, points it at addr and starts a
// goroutine that calls Echo, AppError and BaseCall on the First service
// every 100ms, logging each result. The goroutine runs until the process
// exits.
func runClient1(hyperbahnService string, addr net.Addr) error {
	tchan, err := tchannel.NewChannel("client1", optsFor("client1"))
	if err != nil {
		return err
	}
	tchan.Peers().Add(addr.String())
	tclient := thrift.NewClient(tchan, hyperbahnService, nil)
	client := gen.NewTChanFirstClient(tclient)

	go func() {
		for {
			// All three calls share one context with a one second budget.
			ctx, cancel := thrift.NewContext(time.Second)
			res, err := client.Echo(ctx, "Hi")
			log.Println("Echo(Hi) = ", res, ", err: ", err)
			log.Println("AppError() = ", client.AppError(ctx))
			log.Println("BaseCall() = ", client.BaseCall(ctx))
			cancel()
			time.Sleep(100 * time.Millisecond)
		}
	}()
	return nil
}

// runClient2 creates the "client2" channel, points it at addr and starts a
// goroutine that calls Test on the Second service every 100ms. Failed calls
// are logged rather than silently dropped. The goroutine runs until the
// process exits.
func runClient2(hyperbahnService string, addr net.Addr) error {
	tchan, err := tchannel.NewChannel("client2", optsFor("client2"))
	if err != nil {
		return err
	}
	tchan.Peers().Add(addr.String())
	tclient := thrift.NewClient(tchan, hyperbahnService, nil)
	client := gen.NewTChanSecondClient(tclient)

	go func() {
		for {
			ctx, cancel := thrift.NewContext(time.Second)
			if err := client.Test(ctx); err != nil {
				log.Println("Test() failed: ", err)
			}
			cancel()
			time.Sleep(100 * time.Millisecond)
		}
	}()
	return nil
}

// listenConsole reads commands from stdin, one per line. "s" prints all
// goroutine stacks; anything else is reported as unrecognized. It returns
// once stdin reports an error (including EOF) instead of spinning on a
// closed input.
func listenConsole() {
	rdr := bufio.NewReader(os.Stdin)
	for {
		line, err := rdr.ReadString('\n')
		// ReadString may return a final unterminated line together with the
		// error, so handle whatever was read before deciding to stop.
		if line != "" {
			switch strings.TrimSpace(line) {
			case "s":
				printStack()
			default:
				fmt.Println("Unrecognized command:", line)
			}
		}
		if err != nil {
			return
		}
	}
}

// printStack prints the stacks of all goroutines, truncated to the size of
// the local buffer.
func printStack() {
	buf := make([]byte, 10000)
	// Only the first n bytes are valid; the rest of buf is zero padding.
	n := runtime.Stack(buf, true /* all */)
	fmt.Println("Stack:\n", string(buf[:n]))
}

// firstHandler implements the First service (and the Base service it
// extends) for the example server.
type firstHandler struct{}

// Healthcheck logs the call and reports a healthy status with message "OK".
func (h *firstHandler) Healthcheck(ctx thrift.Context) (*gen.HealthCheckRes, error) {
	log.Printf("first: HealthCheck()\n")
	return &gen.HealthCheckRes{
		Healthy: true,
		Msg:     "OK"}, nil
}

// BaseCall logs the call and succeeds.
func (h *firstHandler) BaseCall(ctx thrift.Context) error {
	log.Printf("first: BaseCall()\n")
	return nil
}

// Echo logs the call and returns msg unchanged.
func (h *firstHandler) Echo(ctx thrift.Context, msg string) (r string, err error) {
	log.Printf("first: Echo(%v)\n", msg)
	return msg, nil
}

// AppError logs the call and always fails with "app error".
func (h *firstHandler) AppError(ctx thrift.Context) error {
	log.Printf("first: AppError()\n")
	return errors.New("app error")
}
func (h *firstHandler) OneWay(ctx thrift.Context) error { log.Printf("first: OneWay()\n") return errors.New("OneWay error...won't be seen by client") } type secondHandler struct{} func (h *secondHandler) Test(ctx thrift.Context) error { log.Println("secondHandler: Test()") return nil } func optsFor(processName string) *tchannel.ChannelOptions { return &tchannel.ChannelOptions{ ProcessName: processName, Logger: tchannel.NewLevelLogger(tchannel.SimpleLogger, tchannel.LogLevelWarn), } } ================================================ FILE: fragmentation_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"bytes"
	"io"
	"io/ioutil"
	"sync"
	"testing"

	"github.com/uber/tchannel-go/typed"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

const (
	testFragmentHeaderSize = 1 /* flags */ + 1 /* checksum type */ + 4 /* CRC32 checksum */

	testFragmentPayloadSize = 10 // enough room for a small payload
	testFragmentSize        = testFragmentHeaderSize + testFragmentPayloadSize
)

func TestFragmentationEmptyArgs(t *testing.T) {
	runFragmentationTest(t, []string{"", "", ""}, buffers([][]byte{{
		0x0000, // flags
		byte(ChecksumTypeCrc32), 0x0000, 0x0000, 0x0000, 0x0000, // empty checksum
		0x0000, 0x0000, // arg 1 (length no body)
		0x0000, 0x0000, // arg 2 (length no body)
		0x0000, 0x0000, // arg 3 (length no body)
	}}))
}

func TestFragmentationSingleFragment(t *testing.T) {
	runFragmentationTest(t, []string{"A", "B", "C"}, buffers([][]byte{{
		0x0000, // flags
		byte(ChecksumTypeCrc32), 0xa3, 0x83, 0x3, 0x48, // CRC32 checksum
		0x0000, 0x0001, 'A', // arg 1 (length single character body)
		0x0000, 0x0001, 'B', // arg 2 (length single character body)
		0x0000, 0x0001, 'C', // arg 3 (length single character body)
	}}))
}

func TestFragmentationMultipleFragments(t *testing.T) {
	runFragmentationTest(t,
		[]string{"ABCDEFHIJKLM", "NOPQRZTUWXYZ", "012345678"}, buffers(
			[][]byte{{
				0x0001, // has more fragments
				byte(ChecksumTypeCrc32), 0x98, 0x43, 0x9a, 0x45, // checksum
				0x0000, 0x0008, 'A', 'B', 'C', 'D', 'E', 'F', 'H', 'I'}}, // first 8 bytes of arg 1
			[][]byte{{
				0x0001, // has more fragments
				byte(ChecksumTypeCrc32), 0xaf, 0xb9, 0x9c, 0x98, // checksum
				0x0000, 0x0004, 'J', 'K', 'L', 'M', // remaining 4 bytes of arg 1
				0x0000, 0x0002, 'N', 'O'}}, // all of arg 2 that fits (2 bytes)
			[][]byte{{
				0x0001, // has more fragments
				byte(ChecksumTypeCrc32), 0x23, 0xae, 0x2f, 0x37, // checksum
				0x0000, 0x0008, 'P', 'Q', 'R', 'Z', 'T', 'U', 'W', 'X'}}, // more of arg 2
			[][]byte{{
				0x0001, // has more fragments
				byte(ChecksumTypeCrc32), 0xa2, 0x93, 0x74, 0xd8, // checksum
				0x0000, 0x0002, 'Y', 'Z', // last parts of arg 2
				0x0000, 0x0004, '0', '1', '2', '3'}}, // first parts of arg 3
			[][]byte{{
				0x0000, // no more fragments
				byte(ChecksumTypeCrc32), 0xf3, 0x29, 0xbb, 0xd1, // checksum
				0x0000, 0x0005, '4', '5', '6', '7', '8'}}, // remaining 5 bytes of arg 3
		))
}

func TestFragmentationMiddleArgNearFragmentBoundary(t *testing.T) {
	// This covers the case where an argument in the middle ends near the
	// end of a fragment boundary, such that there is not enough room to
	// put another argument in the fragment. In this case there should be
	// an empty chunk for that argument in the next fragment
	runFragmentationTest(t,
		[]string{"ABCDEF", "NOPQ"}, buffers(
			[][]byte{{
				0x0001, // has more fragments
				byte(ChecksumTypeCrc32), 0xbb, 0x76, 0xfe, 0x69, // CRC32 checksum
				0x0000, 0x0006, 'A', 'B', 'C', 'D', 'E', 'F'}}, // all of arg 1
			[][]byte{{
				0x0000, // no more fragments
				byte(ChecksumTypeCrc32), 0x5b, 0x3c, 0x54, 0xfe, // CRC32 checksum
				0x0000, 0x0000, // empty chunk indicating the end of arg 1
				0x0000, 0x0004, 'N', 'O', 'P', 'Q'}}, // all of arg 2
		))
}

func TestFragmentationMiddleArgOnExactFragmentBoundary(t *testing.T) {
	// This covers the case where an argument in the middle ends exactly at the end of a fragment.
	// Again, there should be an empty chunk for that argument in the next fragment
	runFragmentationTest(t,
		[]string{"ABCDEFGH", "NOPQ"}, buffers(
			[][]byte{{
				0x0001, // has more fragments
				byte(ChecksumTypeCrc32), 0x68, 0xdc, 0xb6, 0x1c, // CRC32 checksum
				0x0000, 0x0008, 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'}}, // all of arg 1
			[][]byte{{
				0x0000, // no more fragments
				byte(ChecksumTypeCrc32), 0x32, 0x66, 0xf, 0x25, // CRC32 checksum
				0x0000, 0x0000, // empty chunk indicating the end of arg 1
				0x0000, 0x0004, 'N', 'O', 'P', 'Q'}}, // all of arg 2
		))
}

func TestFragmentationLastArgOnNearFragmentBoundary(t *testing.T) {
	// Covers the case where the last argument ends near a fragment
	// boundary. No new fragments should get created
	runFragmentationTest(t,
		[]string{"ABCDEF"}, buffers(
			[][]byte{{
				0x0000, // no more fragments
				byte(ChecksumTypeCrc32), 0xbb, 0x76, 0xfe, 0x69, // CRC32 checksum
				0x0000, 0x0006, 'A', 'B', 'C', 'D', 'E', 'F'}}, // all of arg 1
		))
}

func TestFragmentationLastArgOnExactFragmentBoundary(t *testing.T) {
	// Covers the case where the last argument ends exactly on a fragment
	// boundary. No new fragments should get created
	runFragmentationTest(t,
		[]string{"ABCDEFGH"}, buffers(
			[][]byte{{
				0x0000, // no more fragments
				byte(ChecksumTypeCrc32), 0x68, 0xdc, 0xb6, 0x1c, // CRC32 checksum
				0x0000, 0x0008, 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'}}, // all of arg 1
		))
}

func TestFragmentationWriterErrors(t *testing.T) {
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// Write without starting argument
		_, err := w.Write([]byte("foo"))
		assert.Error(t, err)
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// BeginArgument twice without closing the first argument
		assert.NoError(t, w.BeginArgument(false /* last */))
		assert.Error(t, w.BeginArgument(false /* last */))
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// BeginArgument after writing final argument
		writer, err := w.ArgWriter(true /* last */)
		assert.NoError(t, err)
		assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello")))
		assert.Error(t, w.BeginArgument(false /* last */))
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// Close without beginning argument
		assert.Error(t, w.Close())
	})
}

func TestFragmentationReaderErrors(t *testing.T) {
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// Read without starting argument
		b := make([]byte, 10)
		_, err := r.Read(b)
		assert.Error(t, err)
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// Close without beginning argument
		assert.Error(t, r.Close())
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// BeginArgument after reading final argument
		writer, err := w.ArgWriter(true /* last */)
		assert.NoError(t, err)
		assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello")))

		reader, err := r.ArgReader(true /* last */)
		assert.NoError(t, err)
		var arg []byte
		assert.NoError(t, NewArgReader(reader, nil).Read(&arg))
		assert.Equal(t, "hello", string(arg))
		assert.Error(t, r.BeginArgument(false /* last */))
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// Sender sent final argument, but receiver thinks there is more
		writer, err := w.ArgWriter(true /* last */)
		assert.NoError(t, err)
		assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello")))

		reader, err := r.ArgReader(false /* last */)
		assert.NoError(t, err)
		var arg []byte
		assert.Error(t, NewArgReader(reader, nil).Read(&arg))
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// Close without receiving all data in chunk
		writer, err := w.ArgWriter(true /* last */)
		assert.NoError(t, err)
		assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello")))

		assert.NoError(t, r.BeginArgument(true /* last */))
		b := make([]byte, 3)
		_, err = r.Read(b)
		assert.NoError(t, err)
		assert.Equal(t, "hel", string(b))
		assert.Error(t, r.Close())
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// Close without receiving all fragments
		writer, err := w.ArgWriter(true /* last */)
		assert.NoError(t, err)
		assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello world what's up")))

		assert.NoError(t, r.BeginArgument(true /* last */))
		b := make([]byte, 8)
		_, err = r.Read(b)
		assert.NoError(t, err)
		assert.Equal(t, "hello wo", string(b))
		assert.Error(t, r.Close())
	})
	runFragmentationErrorTest(func(w *fragmentingWriter, r *fragmentingReader) {
		// BeginArgument while argument is in process
		writer, err := w.ArgWriter(true /* last */)
		assert.NoError(t, err)
		assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello world what's up")))

		assert.NoError(t, r.BeginArgument(false /* last */))
		assert.Error(t, r.BeginArgument(false /* last */))
	})
}

func TestFragmentationChecksumTypeErrors(t *testing.T) {
	sendCh := make(fragmentChannel, 10)
	recvCh := make(fragmentChannel, 10)
	w := newFragmentingWriter(NullLogger, sendCh, ChecksumTypeCrc32.New())
	r := newFragmentingReader(NullLogger, recvCh)

	// Write two fragments out
	writer, err := w.ArgWriter(true /* last */)
	assert.NoError(t, err)
	assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello world what's up")))

	// Intercept and change the checksum type between the first and second fragment
	first := <-sendCh
	recvCh <- first

	second := <-sendCh
	second[1] = byte(ChecksumTypeCrc32C)
	recvCh <- second

	// Attempt to read, should fail
	reader, err := r.ArgReader(true /* last */)
	assert.NoError(t, err)

	var arg []byte
	assert.Error(t, NewArgReader(reader, nil).Read(&arg))
}

func TestFragmentationChecksumMismatch(t *testing.T) {
	sendCh := make(fragmentChannel, 10)
	recvCh := make(fragmentChannel, 10)
	w := newFragmentingWriter(NullLogger, sendCh, ChecksumTypeCrc32.New())
	r := newFragmentingReader(NullLogger, recvCh)

	// Write two fragments out
	writer, err := w.ArgWriter(true /* last */)
	assert.NoError(t, err)
	assert.NoError(t, NewArgWriter(writer, nil).Write([]byte("hello world this is two")))

	// Intercept and change the checksum value in the second fragment
	first := <-sendCh
	recvCh <- first

	second := <-sendCh
	second[2], second[3], second[4], second[5] = 0x01, 0x02, 0x03, 0x04
	recvCh <- second

	// Attempt to read, should fail due to mismatch between local checksum and peer supplied checksum
	reader, err := r.ArgReader(true /* last */)
	assert.NoError(t, err)

	_, err = io.Copy(ioutil.Discard, reader)
	assert.Equal(t, errMismatchedChecksums, err)
}

// runFragmentationErrorTest runs f against a writer and reader that share a
// single buffered fragment channel, so fragments flushed by the writer are
// directly readable by the reader.
func runFragmentationErrorTest(f func(w *fragmentingWriter, r *fragmentingReader)) {
	ch := make(fragmentChannel, 10)
	w := newFragmentingWriter(NullLogger, ch, ChecksumTypeCrc32.New())
	r := newFragmentingReader(NullLogger, ch)
	f(w, r)
}

// runFragmentationTest writes args through a fragmentingWriter, relays every
// produced fragment to a fragmentingReader, and verifies both that the reader
// reconstructs args and that the fragments on the wire equal
// expectedFragments. args must be non-empty; the final element is written and
// read as the last argument.
func runFragmentationTest(t *testing.T, args []string, expectedFragments [][]byte) {
	sendCh := make(fragmentChannel, 10)
	recvCh := make(fragmentChannel, 10)

	w := newFragmentingWriter(NullLogger, sendCh, ChecksumTypeCrc32.New())
	r := newFragmentingReader(NullLogger, recvCh)

	var fragments [][]byte
	var actualArgs []string
	var wg sync.WaitGroup

	// Relay: record each fragment the writer emits and hand it to the reader.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for fragment := range sendCh {
			fragments = append(fragments, fragment)
			recvCh <- fragment
		}
	}()

	// Reader: consume one argument per entry in args.
	wg.Add(1)
	go func() {
		defer wg.Done()

		// assert (not require) is used in this goroutine: require calls
		// t.FailNow, which must only be called from the goroutine running
		// the test function.
		for i := range args {
			last := i == len(args)-1
			reader, err := r.ArgReader(last)
			if !assert.NoError(t, err) {
				return
			}

			var arg []byte
			if !assert.NoError(t, NewArgReader(reader, nil).Read(&arg)) {
				return
			}
			actualArgs = append(actualArgs, string(arg))
		}
	}()

	for i, arg := range args {
		last := i == len(args)-1
		writer, err := w.ArgWriter(last)
		// A failed ArgWriter returns a nil writer, so stop here instead of
		// using it.
		require.NoError(t, err)
		require.NoError(t, NewArgWriter(writer, nil).Write([]byte(arg)))
	}
	close(sendCh)

	wg.Wait()

	assert.Equal(t, args, actualArgs)
	// The loop below indexes fragments by position, so a length mismatch
	// must stop the test rather than fall through to an out-of-range panic.
	require.Equal(t, len(expectedFragments), len(fragments), "incorrect number of fragments")
	for i := 0; i < len(expectedFragments); i++ {
		expectedFragment, fragment := expectedFragments[i], fragments[i]
		assert.Equal(t, expectedFragment, fragment, "incorrect fragment %d", i)
	}
}

// fragmentChannel is a test double that acts as both a fragmentSender and a
// fragmentReceiver, passing serialized fragments over a channel.
type fragmentChannel chan []byte

// newFragment allocates a writable fragment of testFragmentSize bytes and
// reserves space for the flags byte, checksum type and checksum value.
func (ch fragmentChannel) newFragment(initial bool, checksum Checksum) (*writableFragment, error) {
	wbuf := typed.NewWriteBuffer(make([]byte, testFragmentSize))
	fragment := new(writableFragment)
	fragment.flagsRef = wbuf.DeferByte()
	wbuf.WriteSingleByte(byte(checksum.TypeCode()))
	fragment.checksumRef = wbuf.DeferBytes(checksum.Size())
	fragment.checksum = checksum
	fragment.contents = wbuf
	return fragment, wbuf.Err()
}

// flushFragment serializes the fragment and sends the bytes on the channel.
func (ch fragmentChannel) flushFragment(fragment *writableFragment) error {
	var buf bytes.Buffer
	fragment.contents.FlushTo(&buf)
	ch <- buf.Bytes()
	return nil
}

// recvNextFragment blocks for the next serialized fragment on the channel
// and parses its flags, checksum type and checksum; the rest of the buffer
// is left as the fragment contents.
func (ch fragmentChannel) recvNextFragment(initial bool) (*readableFragment, error) {
	rbuf := typed.NewReadBuffer(<-ch)
	fragment := new(readableFragment)
	fragment.onDone = func() {}
	fragment.flags = rbuf.ReadSingleByte()
	fragment.checksumType = ChecksumType(rbuf.ReadSingleByte())
	fragment.checksum = rbuf.ReadBytes(fragment.checksumType.ChecksumSize())
	fragment.contents = rbuf
	return fragment, rbuf.Err()
}

func (ch fragmentChannel) doneReading(unexpected error) {}
func (ch fragmentChannel) doneSending()                 {}

// buffers joins each element (a list of byte slices) into a single byte
// slice, returning one joined slice per element.
func buffers(elements ...[][]byte) [][]byte {
	var buffers [][]byte
	for i := range elements {
		buffers = append(buffers, bytes.Join(elements[i], []byte{}))
	}

	return buffers
}

================================================
FILE: fragmenting_reader.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "bytes" "errors" "io" "github.com/uber/tchannel-go/typed" ) var ( errMismatchedChecksumTypes = errors.New("peer returned different checksum types between fragments") errMismatchedChecksums = errors.New("different checksums between peer and local") errChunkExceedsFragmentSize = errors.New("peer chunk size exceeds remaining data in fragment") errAlreadyReadingArgument = errors.New("already reading argument") errNotReadingArgument = errors.New("not reading argument") errMoreDataInArgument = errors.New("closed argument reader when there is more data available to read") errExpectedMoreArguments = errors.New("closed argument reader when there may be more data available to read") errNoMoreFragments = errors.New("no more fragments") ) type readableFragment struct { isDone bool flags byte checksumType ChecksumType checksum []byte contents *typed.ReadBuffer onDone func() } func (f *readableFragment) done() { if f.isDone { return } f.onDone() f.isDone = true } type fragmentReceiver interface { // recvNextFragment returns the next received fragment, blocking until // it's available or a deadline/cancel occurs recvNextFragment(intial bool) (*readableFragment, error) // doneReading is called when the fragment receiver is finished reading all fragments. // If an error frame is the last received frame, then doneReading is called with an error. 
doneReading(unexpectedErr error) } type fragmentingReadState int const ( fragmentingReadStart fragmentingReadState = iota fragmentingReadInArgument fragmentingReadInLastArgument fragmentingReadWaitingForArgument fragmentingReadComplete ) func (s fragmentingReadState) isReadingArgument() bool { return s == fragmentingReadInArgument || s == fragmentingReadInLastArgument } type fragmentingReader struct { logger Logger state fragmentingReadState remainingChunks [][]byte curChunk []byte hasMoreFragments bool receiver fragmentReceiver curFragment *readableFragment checksum Checksum err error } func newFragmentingReader(logger Logger, receiver fragmentReceiver) *fragmentingReader { return &fragmentingReader{ logger: logger, receiver: receiver, hasMoreFragments: true, } } // The ArgReader will handle fragmentation as needed. Once the argument has // been read, the ArgReader must be closed. func (r *fragmentingReader) ArgReader(last bool) (ArgReader, error) { if err := r.BeginArgument(last); err != nil { return nil, err } return r, nil } func (r *fragmentingReader) BeginArgument(last bool) error { if r.err != nil { return r.err } switch { case r.state.isReadingArgument(): r.err = errAlreadyReadingArgument return r.err case r.state == fragmentingReadComplete: r.err = errComplete return r.err } // We're guaranteed that either this is the first argument (in which // case we need to get the first fragment and chunk) or that we have a // valid curChunk (populated via Close) if r.state == fragmentingReadStart { if r.err = r.recvAndParseNextFragment(true); r.err != nil { return r.err } } r.state = fragmentingReadInArgument if last { r.state = fragmentingReadInLastArgument } return nil } func (r *fragmentingReader) Read(b []byte) (int, error) { if r.err != nil { return 0, r.err } if !r.state.isReadingArgument() { r.err = errNotReadingArgument return 0, r.err } totalRead := 0 for { // Copy as much data as we can from the current chunk n := copy(b, r.curChunk) totalRead += n 
r.curChunk = r.curChunk[n:] b = b[n:] if len(b) == 0 { // There was enough data in the current chunk to // satisfy the read. Advance our place in the current // chunk and be done return totalRead, nil } // There wasn't enough data in the current chunk to satisfy the // current read. If there are more chunks in the current // fragment, then we've reach the end of this argument. Return // an io.EOF so functions like ioutil.ReadFully know to finish if len(r.remainingChunks) > 0 { return totalRead, io.EOF } // Try to fetch more fragments. If there are no more // fragments, then we've reached the end of the argument if !r.hasMoreFragments { return totalRead, io.EOF } if r.err = r.recvAndParseNextFragment(false); r.err != nil { return totalRead, r.err } } } func (r *fragmentingReader) Close() error { last := r.state == fragmentingReadInLastArgument if r.err != nil { return r.err } if !r.state.isReadingArgument() { r.err = errNotReadingArgument return r.err } if len(r.curChunk) > 0 { // There was more data remaining in the chunk r.err = errMoreDataInArgument return r.err } // Several possibilities here: // 1. The caller thinks this is the last argument, but there are chunks in the current // fragment or more fragments in this message // - give them an error // 2. The caller thinks this is the last argument, and there are no more chunks and no more // fragments // - the stream is complete // 3. The caller thinks there are more arguments, and there are more chunks in this fragment // - advance to the next chunk, this is the first chunk for the next argument // 4. The caller thinks there are more arguments, and there are no more chunks in this fragment, // but there are more fragments in the message // - retrieve the next fragment, confirm it has an empty chunk (indicating the end of the // current argument), advance to the next check (which is the first chunk for the next arg) // 5. 
The caller thinks there are more arguments, but there are no more chunks or fragments available // - give them an err if last { if len(r.remainingChunks) > 0 || r.hasMoreFragments { // We expect more arguments r.err = errExpectedMoreArguments return r.err } r.doneReading(nil) r.curFragment.done() r.curChunk = nil r.state = fragmentingReadComplete return nil } r.state = fragmentingReadWaitingForArgument // If there are more chunks in this fragment, advance to the next chunk. This is the first chunk // for the next argument if len(r.remainingChunks) > 0 { r.curChunk, r.remainingChunks = r.remainingChunks[0], r.remainingChunks[1:] return nil } // If there are no more chunks in this fragment, and no more fragments, we have an issue if !r.hasMoreFragments { r.err = errNoMoreFragments return r.err } // There are no more chunks in this fragments, but more fragments - get the next fragment if r.err = r.recvAndParseNextFragment(false); r.err != nil { return r.err } return nil } func (r *fragmentingReader) recvAndParseNextFragment(initial bool) error { if r.err != nil { return r.err } if r.curFragment != nil { r.curFragment.done() } r.curFragment, r.err = r.receiver.recvNextFragment(initial) if r.err != nil { if err, ok := r.err.(errorMessage); ok { // Serialized system errors are still reported (e.g. latency, trace reporting). 
r.err = err.AsSystemError() r.doneReading(r.err) } return r.err } // Set checksum, or confirm new checksum is the same type as the prior checksum if r.checksum == nil { r.checksum = r.curFragment.checksumType.New() } else if r.checksum.TypeCode() != r.curFragment.checksumType { return errMismatchedChecksumTypes } // Split fragment into underlying chunks r.hasMoreFragments = (r.curFragment.flags & hasMoreFragmentsFlag) == hasMoreFragmentsFlag r.remainingChunks = nil for r.curFragment.contents.BytesRemaining() > 0 && r.curFragment.contents.Err() == nil { chunkSize := r.curFragment.contents.ReadUint16() if chunkSize > uint16(r.curFragment.contents.BytesRemaining()) { return errChunkExceedsFragmentSize } chunkData := r.curFragment.contents.ReadBytes(int(chunkSize)) r.remainingChunks = append(r.remainingChunks, chunkData) r.checksum.Add(chunkData) } if r.curFragment.contents.Err() != nil { return r.curFragment.contents.Err() } // Validate checksums localChecksum := r.checksum.Sum() if bytes.Compare(r.curFragment.checksum, localChecksum) != 0 { r.err = errMismatchedChecksums return r.err } // Pull out the first chunk to act as the current chunk r.curChunk, r.remainingChunks = r.remainingChunks[0], r.remainingChunks[1:] return nil } func (r *fragmentingReader) doneReading(err error) { if r.checksum != nil { r.checksum.Release() } r.receiver.doneReading(err) } ================================================ FILE: fragmenting_writer.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"errors"
	"fmt"

	"github.com/uber/tchannel-go/typed"
)

var (
	errAlreadyWritingArgument = errors.New("already writing argument")
	errNotWritingArgument     = errors.New("not writing argument")
	errComplete               = errors.New("last argument already sent")
)

const (
	chunkHeaderSize      = 2    // each chunk is a uint16
	hasMoreFragmentsFlag = 0x01 // flags indicating there are more fragments coming
)

// A writableFragment is a fragment that can be written to, containing a buffer
// for contents, a running checksum, and placeholders for the fragment flags
// and final checksum value
type writableFragment struct {
	flagsRef    typed.ByteRef
	checksumRef typed.BytesRef
	checksum    Checksum
	contents    *typed.WriteBuffer
	frame       *Frame
}

// finish finishes the fragment, updating the final checksum and fragment flags
func (f *writableFragment) finish(hasMoreFragments bool) {
	f.checksumRef.Update(f.checksum.Sum())
	if hasMoreFragments {
		// Important: hasMoreFragmentsFlag is set if there are more fragments, but NOT CLEARED if there aren't.
		// This allows for callReqContinue frames to follow a fragmented callReq frame e.g. when arg2 is modified
		// by the relayer
		f.flagsRef.Update(hasMoreFragmentsFlag)
	} else {
		f.checksum.Release()
	}
}

// A writableChunk is a chunk of data within a fragment, representing the
// contents of an argument within that fragment
type writableChunk struct {
	size     uint16
	sizeRef  typed.Uint16Ref
	checksum Checksum
	contents *typed.WriteBuffer
}

// newWritableChunk creates a new writable chunk around a checksum and a buffer to hold data
func newWritableChunk(checksum Checksum, contents *typed.WriteBuffer) *writableChunk {
	return &writableChunk{
		size:     0,
		sizeRef:  contents.DeferUint16(),
		checksum: checksum,
		contents: contents,
	}
}

// writeAsFits writes as many bytes from the given slice as fits into the chunk
func (c *writableChunk) writeAsFits(b []byte) int {
	if len(b) > c.contents.BytesRemaining() {
		b = b[:c.contents.BytesRemaining()]
	}

	c.checksum.Add(b)
	c.contents.WriteBytes(b)

	written := len(b)
	c.size += uint16(written)
	return written
}

// finish finishes the chunk, updating its chunk size
func (c *writableChunk) finish() {
	c.sizeRef.Update(c.size)
}

// A fragmentSender allocates and sends outbound fragments to a target
type fragmentSender interface {
	// newFragment allocates a new fragment
	newFragment(initial bool, checksum Checksum) (*writableFragment, error)

	// flushFragment flushes the given fragment
	flushFragment(f *writableFragment) error

	// doneSending is called when the fragment sender is finished sending all fragments.
	doneSending()
}

type fragmentingWriterState int

const (
	fragmentingWriteStart fragmentingWriterState = iota
	fragmentingWriteInArgument
	fragmentingWriteInLastArgument
	fragmentingWriteWaitingForArgument
	fragmentingWriteComplete
)

// isWritingArgument reports whether an argument (last or not) is currently
// being written.
func (s fragmentingWriterState) isWritingArgument() bool {
	return s == fragmentingWriteInArgument || s == fragmentingWriteInLastArgument
}

// A fragmentingWriter writes one or more arguments to an underlying stream,
// breaking them into fragments as needed, and applying an overarching
// checksum. It relies on an underlying fragmentSender, which creates and
// flushes the fragments as needed
type fragmentingWriter struct {
	logger      Logger
	sender      fragmentSender
	checksum    Checksum
	curFragment *writableFragment
	curChunk    *writableChunk
	state       fragmentingWriterState
	err         error
}

// newFragmentingWriter creates a new fragmenting writer
func newFragmentingWriter(logger Logger, sender fragmentSender, checksum Checksum) *fragmentingWriter {
	return &fragmentingWriter{
		logger:   logger,
		sender:   sender,
		checksum: checksum,
		state:    fragmentingWriteStart,
	}
}

// ArgWriter returns an ArgWriter to write an argument. The ArgWriter will handle
// fragmentation as needed. Once the argument is written, the ArgWriter must be closed.
func (w *fragmentingWriter) ArgWriter(last bool) (ArgWriter, error) {
	if err := w.BeginArgument(last); err != nil {
		return nil, err
	}
	return w, nil
}

// BeginArgument tells the writer that the caller is starting a new argument.
// Must not be called while an existing argument is in place
func (w *fragmentingWriter) BeginArgument(last bool) error {
	if w.err != nil {
		return w.err
	}

	switch {
	case w.state == fragmentingWriteComplete:
		w.err = errComplete
		return w.err
	case w.state.isWritingArgument():
		w.err = errAlreadyWritingArgument
		return w.err
	}

	// If we don't have a fragment, request one
	if w.curFragment == nil {
		initial := w.state == fragmentingWriteStart
		if w.curFragment, w.err = w.sender.newFragment(initial, w.checksum); w.err != nil {
			return w.err
		}
	}

	// If there's no room in the current fragment, freak out.  This will
	// only happen due to an implementation error in the TChannel stack
	// itself
	if w.curFragment.contents.BytesRemaining() <= chunkHeaderSize {
		panic(fmt.Errorf("attempting to begin an argument in a fragment with only %d bytes available",
			w.curFragment.contents.BytesRemaining()))
	}

	w.curChunk = newWritableChunk(w.checksum, w.curFragment.contents)
	w.state = fragmentingWriteInArgument
	if last {
		w.state = fragmentingWriteInLastArgument
	}
	return nil
}

// Write writes argument data, breaking it into fragments as needed
func (w *fragmentingWriter) Write(b []byte) (int, error) {
	if w.err != nil {
		return 0, w.err
	}

	if !w.state.isWritingArgument() {
		w.err = errNotWritingArgument
		return 0, w.err
	}

	totalWritten := 0
	for {
		bytesWritten := w.curChunk.writeAsFits(b)
		totalWritten += bytesWritten
		if bytesWritten == len(b) {
			// The whole thing fit, we're done
			return totalWritten, nil
		}

		// There was more data than fit into the fragment, so flush the current fragment,
		// start a new fragment and chunk, and continue writing
		if w.err = w.Flush(); w.err != nil {
			return totalWritten, w.err
		}

		b = b[bytesWritten:]
	}
}

// Flush flushes the current fragment, and starts a new fragment and chunk.
// It may only be called while an argument is being written, and returns the
// sticky error if a previous operation failed.
func (w *fragmentingWriter) Flush() error {
	// After a failure curFragment may be nil (a failed newFragment), so do
	// not touch the fragment state once an error has been recorded.
	if w.err != nil {
		return w.err
	}

	// Outside of an argument there is no open chunk to finish: curChunk is
	// nil before the first BeginArgument, and already finished afterwards.
	if !w.state.isWritingArgument() {
		w.err = errNotWritingArgument
		return w.err
	}

	w.curChunk.finish()
	w.curFragment.finish(true)
	if w.err = w.sender.flushFragment(w.curFragment); w.err != nil {
		return w.err
	}

	if w.curFragment, w.err = w.sender.newFragment(false, w.checksum); w.err != nil {
		return w.err
	}

	w.curChunk = newWritableChunk(w.checksum, w.curFragment.contents)
	return nil
}

// Close ends the current argument.
func (w *fragmentingWriter) Close() error {
	last := w.state == fragmentingWriteInLastArgument
	if w.err != nil {
		return w.err
	}

	if !w.state.isWritingArgument() {
		w.err = errNotWritingArgument
		return w.err
	}

	w.curChunk.finish()

	// There are three possibilities here:
	// 1. There are no more arguments
	//      flush with more_fragments=false, mark the stream as complete
	// 2. There are more arguments, but we can't fit more data into this fragment
	//      flush with more_fragments=true, start new fragment, write empty chunk to indicate
	//      the current argument is complete
	// 3. There are more arguments, and we can fit more data into this fragment
	//      update the chunk but leave the current fragment open
	if last {
		// No more arguments - flush this final fragment and mark ourselves complete
		w.state = fragmentingWriteComplete
		w.curFragment.finish(false)
		w.err = w.sender.flushFragment(w.curFragment)
		w.sender.doneSending()
		return w.err
	}

	w.state = fragmentingWriteWaitingForArgument
	if w.curFragment.contents.BytesRemaining() > chunkHeaderSize {
		// There's enough room in this fragment for the next argument's
		// initial chunk, so we're done here
		return nil
	}

	// This fragment is full - flush and prepare for another argument
	w.curFragment.finish(true)
	if w.err = w.sender.flushFragment(w.curFragment); w.err != nil {
		return w.err
	}

	if w.curFragment, w.err = w.sender.newFragment(false, w.checksum); w.err != nil {
		return w.err
	}

	// Write an empty chunk to indicate this argument has ended
	w.curFragment.contents.WriteUint16(0)
	return nil
}

================================================
FILE: frame.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "encoding/json" "fmt" "io" "math" "github.com/uber/tchannel-go/typed" ) const ( // MaxFrameSize is the total maximum size for a frame MaxFrameSize = math.MaxUint16 // FrameHeaderSize is the size of the header element for a frame FrameHeaderSize = 16 // MaxFramePayloadSize is the maximum size of the payload for a single frame MaxFramePayloadSize = MaxFrameSize - FrameHeaderSize ) // FrameHeader is the header for a frame, containing the MessageType and size type FrameHeader struct { // The size of the frame including the header size uint16 // The type of message represented by the frame messageType messageType // Left empty reserved1 byte // The id of the message represented by the frame ID uint32 // Left empty reserved [8]byte } // SetPayloadSize sets the size of the frame payload func (fh *FrameHeader) SetPayloadSize(size uint16) { fh.size = size + FrameHeaderSize } // PayloadSize returns the size of the frame payload func (fh FrameHeader) PayloadSize() uint16 { return fh.size - FrameHeaderSize } // FrameSize returns the total size of the frame func (fh FrameHeader) FrameSize() uint16 { return fh.size } // MessageType returns the type of the message func (fh FrameHeader) MessageType() byte { return byte(fh.messageType) } func (fh FrameHeader) String() string { return fmt.Sprintf("%v[%d]", fh.messageType, fh.ID) } // MarshalJSON returns a `{"id":NNN, "msgType":MMM, "size":SSS}` representation func (fh FrameHeader) MarshalJSON() ([]byte, error) { s := struct { ID 
uint32 `json:"id"` MsgType messageType `json:"msgType"` Size uint16 `json:"size"` }{fh.ID, fh.messageType, fh.size} return json.Marshal(s) } func (fh *FrameHeader) read(r *typed.ReadBuffer) error { fh.size = r.ReadUint16() fh.messageType = messageType(r.ReadSingleByte()) fh.reserved1 = r.ReadSingleByte() fh.ID = r.ReadUint32() r.ReadBytes(len(fh.reserved)) return r.Err() } func (fh *FrameHeader) write(w *typed.WriteBuffer) error { w.WriteUint16(fh.size) w.WriteSingleByte(byte(fh.messageType)) w.WriteSingleByte(fh.reserved1) w.WriteUint32(fh.ID) w.WriteBytes(fh.reserved[:]) return w.Err() } // A Frame is a header and payload type Frame struct { buffer []byte // full buffer, including payload and header headerBuffer []byte // slice referencing just the header // The header for the frame Header FrameHeader // The payload for the frame Payload []byte } // NewFrame allocates a new frame with the given payload capacity func NewFrame(payloadCapacity int) *Frame { f := &Frame{} f.buffer = make([]byte, payloadCapacity+FrameHeaderSize) f.Payload = f.buffer[FrameHeaderSize:] f.headerBuffer = f.buffer[:FrameHeaderSize] return f } // ReadBody takes in a previously read frame header, and only reads in the body // based on the size specified in the header. This allows callers to defer // the frame allocation till the body needs to be read. func (f *Frame) ReadBody(header []byte, r io.Reader) error { // Copy the header into the underlying buffer so we have an assembled frame // that can be directly forwarded. copy(f.buffer, header) // Parse the header into our typed struct. if err := f.Header.read(typed.NewReadBuffer(header)); err != nil { return err } switch payloadSize := f.Header.PayloadSize(); { case payloadSize > MaxFramePayloadSize: return fmt.Errorf("invalid frame size %v", f.Header.size) case payloadSize > 0: _, err := io.ReadFull(r, f.SizedPayload()) return err default: // No payload to read return nil } } // ReadIn reads the frame from the given io.Reader. 
// Deprecated: Only maintained for backwards compatibility. Callers should // use ReadBody instead. func (f *Frame) ReadIn(r io.Reader) error { header := make([]byte, FrameHeaderSize) if _, err := io.ReadFull(r, header); err != nil { return err } return f.ReadBody(header, r) } // WriteOut writes the frame to the given io.Writer func (f *Frame) WriteOut(w io.Writer) error { var wbuf typed.WriteBuffer wbuf.Wrap(f.headerBuffer) if err := f.Header.write(&wbuf); err != nil { return err } fullFrame := f.buffer[:f.Header.FrameSize()] if _, err := w.Write(fullFrame); err != nil { return err } return nil } // SizedPayload returns the slice of the payload actually used, as defined by the header func (f *Frame) SizedPayload() []byte { return f.Payload[:f.Header.PayloadSize()] } // messageType returns the message type. func (f *Frame) messageType() messageType { return f.Header.messageType } func (f *Frame) write(msg message) error { var wbuf typed.WriteBuffer wbuf.Wrap(f.Payload[:]) if err := msg.write(&wbuf); err != nil { return err } f.Header.ID = msg.ID() f.Header.reserved1 = 0 f.Header.messageType = msg.messageType() f.Header.SetPayloadSize(uint16(wbuf.BytesWritten())) return nil } func (f *Frame) read(msg message) error { var rbuf typed.ReadBuffer rbuf.Wrap(f.SizedPayload()) return msg.read(&rbuf) } ================================================ FILE: frame_pool.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import "sync" // A FramePool is a pool for managing and re-using frames type FramePool interface { // Retrieves a new frame from the pool Get() *Frame // Releases a frame back to the pool Release(f *Frame) } // DefaultFramePool uses the SyncFramePool. var DefaultFramePool = NewSyncFramePool() // DisabledFramePool is a pool that uses the heap and relies on GC. var DisabledFramePool = disabledFramePool{} type disabledFramePool struct{} func (p disabledFramePool) Get() *Frame { return NewFrame(MaxFramePayloadSize) } func (p disabledFramePool) Release(f *Frame) {} type syncFramePool struct { pool *sync.Pool } // NewSyncFramePool returns a frame pool that uses a sync.Pool. 
func NewSyncFramePool() FramePool { return &syncFramePool{ pool: &sync.Pool{New: func() interface{} { return NewFrame(MaxFramePayloadSize) }}, } } func (p syncFramePool) Get() *Frame { return p.pool.Get().(*Frame) } func (p syncFramePool) Release(f *Frame) { p.pool.Put(f) } type channelFramePool chan *Frame // NewChannelFramePool returns a frame pool backed by a channel that has a max capacity. func NewChannelFramePool(capacity int) FramePool { return channelFramePool(make(chan *Frame, capacity)) } func (c channelFramePool) Get() *Frame { select { case frame := <-c: return frame default: return NewFrame(MaxFramePayloadSize) } } func (c channelFramePool) Release(f *Frame) { select { case c <- f: default: // Too many frames in the channel, discard it. } } ================================================ FILE: frame_pool_b_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "math/rand" "sync" "testing" . "github.com/uber/tchannel-go" "go.uber.org/atomic" ) func benchmarkUsing(b *testing.B, pool FramePool) { const numGoroutines = 1000 const maxHoldFrames = 1000 var gotFrames atomic.Uint64 var wg sync.WaitGroup for i := 0; i < numGoroutines; i++ { wg.Add(1) go func() { for { if gotFrames.Load() > uint64(b.N) { break } framesToHold := rand.Intn(maxHoldFrames) gotFrames.Add(uint64(framesToHold)) frames := make([]*Frame, framesToHold) for i := 0; i < framesToHold; i++ { frames[i] = pool.Get() } for i := 0; i < framesToHold; i++ { pool.Release(frames[i]) } } wg.Done() }() } wg.Wait() } func BenchmarkFramePoolDisabled(b *testing.B) { benchmarkUsing(b, DisabledFramePool) } func BenchmarkFramePoolSync(b *testing.B) { benchmarkUsing(b, NewSyncFramePool()) } func BenchmarkFramePoolChannel1000(b *testing.B) { benchmarkUsing(b, NewChannelFramePool(1000)) } func BenchmarkFramePoolChannel10000(b *testing.B) { benchmarkUsing(b, NewChannelFramePool(10000)) } ================================================ FILE: frame_pool_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test // This file contains functions for tests to access internal tchannel state. // Since it has a _test.go suffix, it is only compiled with tests in this package. import ( "bytes" "io" "math/rand" "sync" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/testutils/testreader" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) type swapper struct { t testing.TB } func (s *swapper) OnError(ctx context.Context, err error) { s.t.Errorf("OnError: %v", err) } func (*swapper) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ Arg2: args.Arg3, Arg3: args.Arg2, }, nil } func doPingAndCall(t testing.TB, clientCh *Channel, hostPort string) { ctx, cancel := NewContext(time.Second * 5) defer cancel() require.NoError(t, clientCh.Ping(ctx, hostPort)) const maxRandArg = 512 * 1024 arg2 := testutils.RandBytes(rand.Intn(maxRandArg)) arg3 := testutils.RandBytes(rand.Intn(maxRandArg)) resArg2, resArg3, _, err := raw.Call(ctx, clientCh, hostPort, "swap-server", "swap", arg2, arg3) if !assert.NoError(t, err, "error during sendRecv") { return } // We expect the arguments to be swapped. if bytes.Compare(arg3, resArg2) != 0 { t.Errorf("returned arg2 does not match expected:\n got %v\n want %v", resArg2, arg3) } if bytes.Compare(arg2, resArg3) != 0 { t.Errorf("returned arg2 does not match expected:\n got %v\n want %v", resArg3, arg2) } } func doErrorCall(t testing.TB, clientCh *Channel, hostPort string) { ctx, cancel := NewContext(time.Second * 5) defer cancel() _, _, _, err := raw.Call(ctx, clientCh, hostPort, "swap-server", "non-existent", nil, nil) assert.Error(t, err, "Call to non-existent endpoint should fail") assert.Equal(t, ErrCodeBadRequest, GetSystemErrorCode(err), "Error code mismatch") } func TestFramesReleased(t *testing.T) { CheckStress(t) defer testutils.SetTimeout(t, 30*time.Second)() const ( requestsPerGoroutine = 10 numGoroutines = 10 ) pool := NewCheckedFramePoolForTest() opts := testutils.NewOpts(). SetServiceName("swap-server"). SetFramePool(pool). 
AddLogFilter("Couldn't find handler.", 2*numGoroutines*requestsPerGoroutine) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ts.Register(raw.Wrap(&swapper{t}), "swap") clientOpts := testutils.NewOpts().SetFramePool(pool) clientCh := ts.NewClient(clientOpts) // Create an active connection that can be shared by the goroutines by calling Ping. ctx, cancel := NewContext(time.Second) defer cancel() require.NoError(t, clientCh.Ping(ctx, ts.HostPort())) var wg sync.WaitGroup for i := 0; i < numGoroutines; i++ { wg.Add(1) go func() { defer wg.Done() for i := 0; i < requestsPerGoroutine; i++ { doPingAndCall(t, clientCh, ts.HostPort()) doErrorCall(t, clientCh, ts.HostPort()) } }() } wg.Wait() }) CheckFramePoolIsEmpty(t, pool) } type dirtyFramePool struct{} func (p dirtyFramePool) Get() *Frame { f := NewFrame(MaxFramePayloadSize) reader := testreader.Looper([]byte{^byte(0)}) io.ReadFull(reader, f.Payload) return f } func (p dirtyFramePool) Release(f *Frame) {} func TestDirtyFrameRequests(t *testing.T) { argSizes := []int{25000, 50000, 75000} // Create the largest required random cache. testutils.RandBytes(argSizes[len(argSizes)-1]) opts := testutils.NewOpts(). SetServiceName("swap-server"). SetFramePool(dirtyFramePool{}) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ts.Register(raw.Wrap(&swapper{t}), "swap") for _, argSize := range argSizes { ctx, cancel := NewContext(time.Second) defer cancel() arg2, arg3 := testutils.RandBytes(argSize), testutils.RandBytes(argSize) res2, res3, _, err := raw.Call(ctx, ts.Server(), ts.HostPort(), ts.Server().ServiceName(), "swap", arg2, arg3) if assert.NoError(t, err, "Call failed") { assert.Equal(t, arg2, res3, "Result arg3 wrong") assert.Equal(t, arg3, res2, "Result arg3 wrong") } } }) } ================================================ FILE: frame_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel import ( "bytes" "encoding/binary" "encoding/json" "io" "math" "testing" "testing/iotest" "testing/quick" "github.com/uber/tchannel-go/testutils/testreader" "github.com/uber/tchannel-go/typed" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func fakeHeader(t messageType) FrameHeader { return FrameHeader{ size: uint16(0xFF34), messageType: t, ID: 0xDEADBEEF, } } func TestFrameHeaderJSON(t *testing.T) { fh := fakeHeader(messageTypeCallReq) logged, err := json.Marshal(fh) assert.NoError(t, err, "FrameHeader can't be marshalled to JSON") assert.Equal( t, string(logged), `{"id":3735928559,"msgType":3,"size":65332}`, "FrameHeader didn't marshal to JSON as expected", ) } func TestFraming(t *testing.T) { fh := fakeHeader(messageTypeCallReq) wbuf := typed.NewWriteBufferWithSize(1024) require.Nil(t, fh.write(wbuf)) var b bytes.Buffer if _, err := wbuf.FlushTo(&b); err != nil { require.Nil(t, err) } rbuf := typed.NewReadBuffer(b.Bytes()) var fh2 FrameHeader require.Nil(t, fh2.read(rbuf)) assert.Equal(t, fh, fh2) } func TestPartialRead(t *testing.T) { f := NewFrame(MaxFramePayloadSize) f.Header.size = FrameHeaderSize + 2134 f.Header.messageType = messageTypeCallReq f.Header.ID = 0xDEADBEED // We set the full payload but only the first 2134 bytes should be written. for i := 0; i < len(f.Payload); i++ { val := (i * 37) % 256 f.Payload[i] = byte(val) } buf := &bytes.Buffer{} require.NoError(t, f.WriteOut(buf)) assert.Equal(t, f.Header.size, uint16(buf.Len()), "frame size should match written bytes") // Read the data back, from a reader that fragments. f2 := NewFrame(MaxFramePayloadSize) require.NoError(t, f2.ReadIn(iotest.OneByteReader(buf))) // Ensure header and payload are the same. 
require.Equal(t, f.Header, f2.Header, "frame headers don't match") require.Equal(t, f.SizedPayload(), f2.SizedPayload(), "payload does not match") } func TestFrameReadShortFrame(t *testing.T) { headerFull := make([]byte, FrameHeaderSize) headerFull[1] = FrameHeaderSize + 1 // give the frame a non-zero size. body := []byte{1} f := NewFrame(MaxFramePayloadSize) err := f.ReadBody(headerFull, bytes.NewReader(body)) require.NoError(t, err, "Should not fail to read full frame header") for i := 0; i < FrameHeaderSize; i++ { partialHeader := headerFull[:i] f := NewFrame(MaxFramePayloadSize) err := f.ReadBody(partialHeader, bytes.NewReader(body)) assert.Equal(t, typed.ErrEOF, err, "Expected short header to fail") } } func TestEmptyPayload(t *testing.T) { f := NewFrame(MaxFramePayloadSize) m := &pingRes{id: 1} require.NoError(t, f.write(m)) // Write out the frame. buf := &bytes.Buffer{} require.NoError(t, f.WriteOut(buf)) assert.Equal(t, FrameHeaderSize, buf.Len()) // Read the frame from the buffer. // net.Conn returns io.EOF if you try to read 0 bytes at the end. // This is also simulated by the LimitedReader so we use that here. 
require.NoError(t, f.ReadIn(&io.LimitedReader{R: buf, N: FrameHeaderSize})) } func TestReservedBytes(t *testing.T) { // Set up a frame with non-zero values f := NewFrame(MaxFramePayloadSize) reader := testreader.Looper([]byte{^byte(0)}) io.ReadFull(reader, f.Payload) f.Header.read(typed.NewReadBuffer(f.Payload)) m := &pingRes{id: 1} f.write(m) buf := &bytes.Buffer{} f.WriteOut(buf) assert.Equal(t, []byte{ 0x0, 0x10, // size 0xd1, // type 0x0, // reserved should always be 0 0x0, 0x0, 0x0, 0x1, // id 0x0, 0x0, 0x0, 0x0, // reserved should always be 0 0x0, 0x0, 0x0, 0x0, // reserved should always be 0 }, buf.Bytes(), "Unexpected bytes") } func TestMessageType(t *testing.T) { frame := NewFrame(MaxFramePayloadSize) err := frame.write(&callReq{Service: "foo"}) require.NoError(t, err, "Error writing message to frame.") assert.Equal(t, messageTypeCallReq, frame.messageType(), "Failed to read message type from frame.") } func TestFrameReadIn(t *testing.T) { maxPayload := bytes.Repeat([]byte{1}, MaxFramePayloadSize) tests := []struct { msg string bs []byte wantFrameHeader FrameHeader wantFramePayload []byte wantErr string }{ { msg: "frame with no payload", bs: []byte{ 0, 16 /* size */, 1 /* type */, 2 /* reserved */, 0, 0, 0, 3, /* id */ 9, 8, 7, 6, 5, 4, 3, 2, // reserved }, wantFrameHeader: FrameHeader{ size: 16, messageType: 1, reserved1: 2, ID: 3, // reserved: [8]byte{9, 8, 7, 6, 5, 4, 3, 2}, // currently ignored. }, wantFramePayload: []byte{}, }, { msg: "frame with small payload", bs: []byte{ 0, 18 /* size */, 1 /* type */, 2 /* reserved */, 0, 0, 0, 3, /* id */ 9, 8, 7, 6, 5, 4, 3, 2, // reserved 100, 200, // payload }, wantFrameHeader: FrameHeader{ size: 18, messageType: 1, reserved1: 2, ID: 3, // reserved: [8]byte{9, 8, 7, 6, 5, 4, 3, 2}, // currently ignored. 
}, wantFramePayload: []byte{100, 200}, }, { msg: "frame with max size", bs: append([]byte{ math.MaxUint8, math.MaxUint8 /* size */, 1 /* type */, 2 /* reserved */, 0, 0, 0, 3, /* id */ 9, 8, 7, 6, 5, 4, 3, 2, // reserved }, maxPayload...), wantFrameHeader: FrameHeader{ size: math.MaxUint16, messageType: 1, reserved1: 2, ID: 3, // currently ignored. // reserved: [8]byte{9, 8, 7, 6, 5, 4, 3, 2}, }, wantFramePayload: maxPayload, }, { msg: "frame with 0 size", bs: []byte{ 0, 0 /* size */, 1 /* type */, 2 /* reserved */, 0, 0, 0, 3, /* id */ 9, 8, 7, 6, 5, 4, 3, 2, // reserved }, wantErr: "invalid frame size 0", }, { msg: "frame with size < HeaderSize", bs: []byte{ 0, 15 /* size */, 1 /* type */, 2 /* reserved */, 0, 0, 0, 3, /* id */ 9, 8, 7, 6, 5, 4, 3, 2, // reserved }, wantErr: "invalid frame size 15", }, { msg: "frame with partial header", bs: []byte{ 0, 16 /* size */, 1 /* type */, 2 /* reserved */, 0, 0, 0, 3, /* id */ // missing reserved bytes }, wantErr: "unexpected EOF", }, { msg: "frame with partial payload", bs: []byte{ 0, 24 /* size */, 1 /* type */, 2 /* reserved */, 0, 0, 0, 3, /* id */ 9, 8, 7, 6, 5, 4, 3, 2, // reserved 1, 2, // partial payload }, wantErr: "unexpected EOF", }, } for _, tt := range tests { f := DefaultFramePool.Get() r := bytes.NewReader(tt.bs) err := f.ReadIn(r) if tt.wantErr != "" { require.Error(t, err, tt.msg) assert.Contains(t, err.Error(), tt.wantErr, tt.msg) continue } require.NoError(t, err, tt.msg) assert.Equal(t, tt.wantFrameHeader, f.Header, "%v: header mismatch", tt.msg) assert.Equal(t, tt.wantFramePayload, f.SizedPayload(), "%v: unexpected payload") } } func frameReadIn(bs []byte) (decoded bool) { frame := DefaultFramePool.Get() defer DefaultFramePool.Release(frame) defer func() { if r := recover(); r != nil { decoded = false } }() frame.ReadIn(bytes.NewReader(bs)) return true } func TestQuickFrameReadIn(t *testing.T) { // Try to read any set of bytes as a frame. 
err := quick.Check(frameReadIn, &quick.Config{MaxCount: 10000}) require.NoError(t, err, "Failed to fuzz test ReadIn") // Limit the search space to just headers. err = quick.Check(func(size uint16, t byte, id uint32) bool { bs := make([]byte, FrameHeaderSize) binary.BigEndian.PutUint16(bs[0:2], size) bs[2] = t binary.BigEndian.PutUint32(bs[4:8], id) return frameReadIn(bs) }, &quick.Config{MaxCount: 10000}) require.NoError(t, err, "Failed to fuzz test ReadIn") } ================================================ FILE: frame_utils_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel import ( "fmt" "sync" "unsafe" "github.com/prashantv/protectmem" ) type protectMemAllocs struct { frameAlloc *protectmem.Allocation bufferAlloc *protectmem.Allocation } type ProtectMemFramePool struct { sync.Mutex allocations map[*Frame]protectMemAllocs } // NewProtectMemFramePool creates a frame pool that ensures that released frames // are not reused by removing all access to a frame once it's been released. func NewProtectMemFramePool() FramePool { return &ProtectMemFramePool{ allocations: make(map[*Frame]protectMemAllocs), } } func (p *ProtectMemFramePool) Get() *Frame { frameAlloc := protectmem.Allocate(unsafe.Sizeof(Frame{})) f := (*Frame)(frameAlloc.Ptr()) bufferAlloc := protectmem.AllocateSlice(&f.buffer, MaxFramePayloadSize) f.buffer = f.buffer[:MaxFramePayloadSize] f.Payload = f.buffer[FrameHeaderSize:] f.headerBuffer = f.buffer[:FrameHeaderSize] p.Lock() p.allocations[f] = protectMemAllocs{ frameAlloc: frameAlloc, bufferAlloc: bufferAlloc, } p.Unlock() return f } func (p *ProtectMemFramePool) Release(f *Frame) { p.Lock() allocs, ok := p.allocations[f] delete(p.allocations, f) p.Unlock() if !ok { panic(fmt.Errorf("released frame that was not allocated by pool: %v", f.Header)) } allocs.bufferAlloc.Protect(protectmem.None) allocs.frameAlloc.Protect(protectmem.None) } ================================================ FILE: go.mod ================================================ module github.com/uber/tchannel-go go 1.21 require ( github.com/HdrHistogram/hdrhistogram-go v0.9.0 // indirect github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b github.com/cactus/go-statsd-client/statsd v0.0.0-20190922033735-5ca90424ceb7 github.com/crossdock/crossdock-go v0.0.0-20160816171116-049aabb0122b github.com/jessevdk/go-flags v1.4.0 github.com/opentracing/opentracing-go v1.1.0 github.com/prashantv/protectmem v0.0.0-20171002184600-e20412882b3a github.com/samuel/go-thrift v0.0.0-20190219015601-e8b6b52668fe github.com/streadway/quantile 
v0.0.0-20220407130108-4246515d968d github.com/stretchr/testify v1.5.1 github.com/uber-go/tally v3.3.15+incompatible github.com/uber/jaeger-client-go v2.22.1+incompatible go.uber.org/atomic v1.6.0 go.uber.org/multierr v1.2.0 golang.org/x/net v0.14.0 golang.org/x/sys v0.11.0 gopkg.in/yaml.v2 v2.4.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.3.0 // indirect github.com/uber/jaeger-lib v2.4.1+incompatible // indirect golang.org/x/tools v0.1.12 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect ) ================================================ FILE: go.sum ================================================ github.com/HdrHistogram/hdrhistogram-go v0.9.0 h1:dpujRju0R4M/QZzcnR1LH1qm+TVG3UzkWdp5tH1WMcg= github.com/HdrHistogram/hdrhistogram-go v0.9.0/go.mod h1:nxrse8/Tzg2tg3DZcZjm6qEclQKK70g0KxO61gFFZD4= github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b h1:AP/Y7sqYicnjGDfD5VcY4CIfh1hRXBUavxrvELjTiOE= github.com/bmizerany/perks v0.0.0-20141205001514-d9a9656a3a4b/go.mod h1:ac9efd0D1fsDb3EJvhqgXRbFx7bs2wqZ10HQPeU8U/Q= github.com/cactus/go-statsd-client/statsd v0.0.0-20190922033735-5ca90424ceb7 h1:QjgH6kpBzpFeQKXnpa6cdfg4F2heAG2sP3CZG+fGS+8= github.com/cactus/go-statsd-client/statsd v0.0.0-20190922033735-5ca90424ceb7/go.mod h1:D4RDtP0MffJ3+R36OkGul0LwJLIN8nRb0Ac6jZmJCmo= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/crossdock/crossdock-go v0.0.0-20160816171116-049aabb0122b h1:WR1qVJzbvrVywhAk4kMQKRPx09AZVI0NdEdYs59iHcA= github.com/crossdock/crossdock-go v0.0.0-20160816171116-049aabb0122b/go.mod h1:v9FBN7gdVTpiD/+LZ7Po0UKvROyT87uLVxTHVky/dlQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/protectmem v0.0.0-20171002184600-e20412882b3a h1:AA9vgIBDjMHPC2McaGPojgV2dcI78ZC0TLNhYCXEKH8= github.com/prashantv/protectmem v0.0.0-20171002184600-e20412882b3a/go.mod h1:lzZQ3Noex5pfAy7mkAeCjcBDteYU85uWWnJ/y6gKU8k= github.com/samuel/go-thrift v0.0.0-20190219015601-e8b6b52668fe h1:gD4vkYmuoWVgdV6UwI3tPo9MtMfVoIRY+Xn9919SJBg= github.com/samuel/go-thrift v0.0.0-20190219015601-e8b6b52668fe/go.mod h1:Vrkh1pnjV9Bl8c3P9zH0/D4NlOHWP5d4/hF4YTULaec= github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs= github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU= github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As= github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/uber-go/tally v3.3.15+incompatible h1:9hLSgNBP28CjIaDmAuRTq9qV+UZY+9PcvAkXO4nNMwg= github.com/uber-go/tally v3.3.15+incompatible/go.mod h1:YDTIBxdXyOU/sCWilKB4bgyufu1cEi0jdVnRdxvjnmU= github.com/uber/jaeger-client-go v2.22.1+incompatible h1:NHcubEkVbahf9t3p75TOCR83gdUHXjRJvjoBh1yACsM= github.com/uber/jaeger-client-go v2.22.1+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg= github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4= go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.14.0 
h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= ================================================ FILE: guide/Thrift_Hyperbahn.md ================================================ # Set up a Go + Thrift + Hyperbahn Service The code matching this guide is [here](../examples/keyvalue). The TChannel+Thrift integration for Go uses code generated by thrift-gen. 
## Dependencies Make sure your [GOPATH is set up](http://golang.org/doc/code.html) before following this guide. You'll need to `go get` the following: * github.com/uber/tchannel-go * github.com/uber/tchannel-go/hyperbahn * github.com/uber/tchannel-go/thrift * github.com/uber/tchannel-go/thrift/thrift-gen Use [Godep](https://github.com/tools/godep) to manage dependencies, as the API is still in development and will change. This example will assume that the service is created in the following directory: `$GOPATH/src/github.com/uber/tchannel-go/examples/keyvalue` You should use your own path and update your import paths accordingly. ## Thrift service Create a [Thrift](https://thrift.apache.org/) file to define your service. For this guide, we'll use: `keyvalue.thrift`: ```thrift service baseService { string HealthCheck() } exception KeyNotFound { 1: string key } exception InvalidKey {} service KeyValue extends baseService { // If the key does not start with a letter, InvalidKey is returned. // If the key does not exist, KeyNotFound is returned. string Get(1: string key) throws ( 1: KeyNotFound notFound 2: InvalidKey invalidKey) // Set returns InvalidKey if an invalid key is sent. void Set(1: string key, 2: string value) } // Returned when the user is not authorized for the Admin service. exception NotAuthorized {} service Admin extends baseService { void clearAll() throws (1: NotAuthorized notAuthorized) } ``` This Thrift specification defines two services: * `KeyValue`: A simple string key-value store. * `Admin`: Management for the key-value store. Both of these services inherit `baseService` and so inherit `HealthCheck`. The methods may return exceptions instead of the expected result, which are also defined in the specification.
Once you have defined your service, you should generate the Thrift service and client libraries by running the following: ```bash cd $GOPATH/src/github.com/uber/tchannel-go/examples/keyvalue thrift-gen --generateThrift --inputFile keyvalue.thrift ``` This runs the Thrift compiler, and then generates the service and client bindings. You can run the commands manually as well: ```bash # Generate serialization/deserialization logic. thrift -r --gen go:thrift_import=github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift keyvalue.thrift # Generate TChannel service interfaces in the same directory where Thrift generates code. thrift-gen --inputFile "$THRIFTFILE" --outputFile "THRIFT_FILE_FOLDER/gen-go/thriftName/tchan-keyvalue.go" ``` ## Go server To get the server ready, the following needs to be done: 1. Create the TChannel which is the network layer protocol. 2. Create a handler to handle the methods defined in the Thrift definition, and register it with tchannel/thrift. 3. Create a Hyperbahn client and advertise your service with Hyperbahn. ### Create a TChannel Create a channel using [tchannel.NewChannel](http://godoc.org/github.com/uber/tchannel-go#NewChannel) and listen using [Channel.ListenAndServe](http://godoc.org/github.com/uber/tchannel-go#Channel.ListenAndServe). The address passed to Listen should be a remote IP that can be used for incoming connections from other machines. You can use [tchannel.ListenIP](http://godoc.org/github.com/uber/tchannel-go#ListenIP) which uses heuristics to determine a good remote IP. When creating a channel, you can pass additional [options](http://godoc.org/github.com/uber/tchannel-go#ChannelOptions). ### Create and register Thrift handler Create a custom type with methods required by the Thrift generated interface. You can examine this interface by looking in `gen-go/keyvalue/tchan-keyvalue.go`. 
For example, the interface for our definition file looks like: ```go type TChanAdmin interface { HealthCheck(ctx thrift.Context) (string, error) ClearAll(ctx thrift.Context) error } type TChanKeyValue interface { Get(ctx thrift.Context, key string) (string, error) HealthCheck(ctx thrift.Context) (string, error) Set(ctx thrift.Context, key string, value string) error } ``` Create an instance of your handler type, and then create a [thrift.Server](http://godoc.org/github.com/uber/tchannel-go/thrift#NewServer) and [register](http://godoc.org/github.com/uber/tchannel-go/thrift#Server.Register) your Thrift handler. You can register multiple Thrift services on the same `thrift.Server`. Each handler method is run in a new goroutine and so must be thread-safe. Your handler methods can return two types of errors: * Errors declared in the Thrift file (e.g. `KeyNotFound`). * Unexpected errors. If you return an unexpected error, an error frame is sent over Thrift with the message. If there are known error cases, it is better to declare them in the Thrift file and return those explicitly, e.g.: ```go if value, ok := values[key]; ok { return value, nil } // Return a Thrift exception if the key is not found. return "", &keyvalue.KeyNotFound{Key: key} ``` ### Advertise with Hyperbahn Create a Hyperbahn client using [hyperbahn.NewClient](http://godoc.org/github.com/uber/tchannel-go/hyperbahn#NewClient) which requires a Hyperbahn configuration object that should be loaded from a configuration file for the current environment. You can also pass more [options](http://godoc.org/github.com/uber/tchannel-go/hyperbahn#ClientOptions) when creating the client. Call [Advertise](http://godoc.org/github.com/uber/tchannel-go/hyperbahn#Client.Advertise) to advertise the service with Hyperbahn. ### Serving Your service is now serving over Hyperbahn!
You can test this by making a call using [tcurl](https://github.com/uber/tcurl): ``` node tcurl.js -p [HYPERBAHN-HOSTPORT] -t [DIR-TO-THRIFT] keyvalue KeyValue::Set -3 '{"key": "hello", "value": "world"}' node tcurl.js -p [HYPERBAHN-HOSTPORT] -t [DIR-TO-THRIFT] keyvalue KeyValue::Get -3 '{"key": "hello"}' ``` Replace `[HYPERBAHN-HOSTPORT]` with the host:port of a Hyperbahn node, and `[DIR-TO-THRIFT]` with the directory where the .thrift file is stored. Your service can now be accessed from any language over Hyperbahn + TChannel! ## Go client Note: The client implementation is still in active development. To make a client that talks, you need to: 1. Create a TChannel (or re-use an existing TChannel) 2. Set up Hyperbahn 3. Create a Thrift+TChannel client. 4. Make remote calls using the Thrift client. ### Create a TChannel TChannels are bi-directional and so the client uses the same method as the server code (tchannel.NewChannel) to create a TChannel. You do not need to call ListenAndServe on the channel. Even though the channel does not host a service, a serviceName is required for TChannel. This serviceName should be unique to identify this client. You can use an existing TChannel which hosts a service to make client calls. ### Set up Hyperbahn Similar to the server code, create a new Hyperbahn client using hyperbahn.NewClient. You do not need to call Advertise, as the client does not have any services to advertise over Hyperbahn. If you have already set up an existing client for use with a server, then you do not need to do anything further. ### Create a Thrift client The Thrift client has two parts: 1. The `thrift.TChanClient` which is configured to hit a specific Hyperbahn service. 2. A generated client which uses an underlying `thrift.TChanClient` to call methods for a specific Thrift service. To create a `thrift.TChanClient`, use `thrift.NewClient`. 
This client can then be used to create a generated client: ```go thriftClient := thrift.NewClient(ch, "keyvalue", nil) client := keyvalue.NewTChanKeyValueClient(thriftClient) adminClient := keyvalue.NewTChanAdminClient(thriftClient) ``` ### Make remote calls Method calls on the client make remote calls over TChannel. E.g. ```go err := client.Set(ctx, "hello", "world") val, err := client.Get(ctx, "hello") // val = "world" ``` You must pass a context when making method calls which passes the deadline, tracing information, and application headers. A simple root context is: ```go ctx, cancel := thrift.NewContext(time.Second) ``` All calls over TChannel are required to have a timeout, and tracing information. NewContext should only be used by edges, all other nodes should pass through the incoming Context. When you pass through a Context, you pass along the deadline, tracing information, and the headers. Note: Trace spans are automatically generated by TChannel, and the parent is set automatically from the current context's tracing span. ## Headers Thrift + TChannel allows clients to send headers (a list of string key/value pairs) and servers can add response headers to any response. 
In Go, headers are attached to a context before a call is made using [WithHeaders](http://godoc.org/github.com/uber/tchannel-go/thrift#WithHeaders): ```go headers := map[string]string{"user": "prashant"} ctx, cancel := thrift.NewContext(time.Second) ctx = thrift.WithHeaders(ctx, headers) ``` The server can read these headers using [Headers](http://godoc.org/github.com/uber/tchannel-go/thrift#Context) and can set additional response headers using `SetResponseHeaders`: ```go func (h *kvHandler) ClearAll(ctx thrift.Context) error { headers := ctx.Headers() // Application logic respHeaders := map[string]string{ "count": "10", } ctx.SetResponseHeaders(respHeaders) return nil } ``` The client can read the response headers by calling `ctx.ResponseHeaders()` on the same context that was passed when making the call: ```go ctx, cancel := thrift.NewContext(time.Second) defer cancel() ctx = thrift.WithHeaders(ctx, headers) err := adminClient.ClearAll(ctx) // check error responseHeaders := ctx.ResponseHeaders() ``` Headers should not be used to pass arguments to the method - the Thrift request/response structs should be used for this. ## Limitations & Upcoming Changes TChannel's peer selection does not yet have a detailed health model for nodes, and selection does not balance load across nodes. The thrift-gen autogenerated code is new, and may not support all Thrift features (E.g. annotations, includes, multiple files) ================================================ FILE: handlers.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"reflect"
	"runtime"
	"sync"

	"golang.org/x/net/context"
)

// A Handler is an object that can be registered with a Channel to process
// incoming calls for a given service and method.
type Handler interface {
	// Handle processes a single incoming call. ctx is the context associated
	// with the call and call is the inbound call being served.
	Handle(ctx context.Context, call *InboundCall)
}

// registrar is a subset of the Registrar interface, only containing Register.
type registrar interface {
	// Register associates the handler h with the given method name.
	Register(h Handler, methodName string)
}

// A HandlerFunc is an adapter to allow the use of ordinary functions as
// Channel handlers. If f is a function with the appropriate signature, then
// HandlerFunc(f) is a Handler object that calls f.
type HandlerFunc func(ctx context.Context, call *InboundCall)

// Handle satisfies Handler by invoking the wrapped function directly.
func (f HandlerFunc) Handle(ctx context.Context, call *InboundCall) {
	f(ctx, call)
}

// An ErrorHandlerFunc is an adapter to allow the use of ordinary functions as
// Channel handlers, with error handling convenience. If f is a function with
// the appropriate signature, then ErrorHandlerFunc(f) is a Handler object that
// calls f.
type ErrorHandlerFunc func(ctx context.Context, call *InboundCall) error

// Handle satisfies Handler by invoking f(ctx, call). A non-nil error is sent
// back to the caller as a system error; errors that map to ErrCodeUnexpected
// are additionally logged together with the identity of the handler function.
func (f ErrorHandlerFunc) Handle(ctx context.Context, call *InboundCall) {
	err := f(ctx, call)
	if err == nil {
		return
	}

	if GetSystemErrorCode(err) == ErrCodeUnexpected {
		logger := call.log.WithFields(f.getLogFields()...)
		logger.WithFields(ErrField(err)).Error("Unexpected handler error")
	}
	call.Response().SendSystemError(err)
}

// getLogFields describes the wrapped function (name, source file and line)
// as log fields, resolved from the function's entry address.
func (f ErrorHandlerFunc) getLogFields() LogFields {
	pc := reflect.ValueOf(f).Pointer()
	fn := runtime.FuncForPC(pc) // can't be nil
	file, line := fn.FileLine(pc)

	return LogFields{
		LogField{"handlerFuncName", fn.Name()},
		LogField{"handlerFuncFileName", file},
		LogField{"handlerFuncFileLine", line},
	}
}

// handlerMap is a lock-protected registry of handlers keyed by method name.
type handlerMap struct {
	sync.RWMutex

	handlers map[string]Handler
}

// Register implements registrar. The underlying map is created on first use,
// so the zero value of handlerMap is ready to register handlers.
func (hmap *handlerMap) Register(h Handler, method string) {
	hmap.Lock()
	defer hmap.Unlock()

	if hmap.handlers == nil {
		hmap.handlers = map[string]Handler{method: h}
		return
	}
	hmap.handlers[method] = h
}

// Finds the handler matching the given service and method.
See https://github.com/golang/go/issues/3512 // for the reason that method is []byte instead of a string func (hmap *handlerMap) find(method []byte) Handler { hmap.RLock() handler := hmap.handlers[string(method)] hmap.RUnlock() return handler } func (hmap *handlerMap) Handle(ctx context.Context, call *InboundCall) { c := call.conn h := hmap.find(call.Method()) if h == nil { c.log.WithFields( LogField{"serviceName", call.ServiceName()}, LogField{"method", call.MethodString()}, ).Error("Couldn't find handler.") call.Response().SendSystemError( NewSystemError(ErrCodeBadRequest, "no handler for service %q and method %q", call.ServiceName(), call.Method())) return } if c.log.Enabled(LogLevelDebug) { c.log.Debugf("Dispatching %s:%s from %s", call.ServiceName(), call.Method(), c.remotePeerInfo) } h.Handle(ctx, call) } // channelHandler is a Handler that wraps a Channel and delegates requests // to SubChannels based on the inbound call's service name. type channelHandler struct{ ch *Channel } func (c channelHandler) Handle(ctx context.Context, call *InboundCall) { c.ch.GetSubChannel(call.ServiceName()).handler.Handle(ctx, call) } // Register registers the handler on the channel's default service name. func (c channelHandler) Register(h Handler, methodName string) { c.ch.GetSubChannel(c.ch.PeerInfo().ServiceName).Register(h, methodName) } // userHandlerWithSkip is a Handler that wraps a localHandler backed by the channel. // and a user provided handler. // The inbound call will be handled by user handler, unless the call's // method name is configured to be handled by localHandler from ignore. 
type userHandlerWithSkip struct { localHandler channelHandler ignoreUserHandler map[string]struct{} // key is service::method userHandler Handler } func (u userHandlerWithSkip) Handle(ctx context.Context, call *InboundCall) { if _, ok := u.ignoreUserHandler[call.MethodString()]; ok { u.localHandler.Handle(ctx, call) return } u.userHandler.Handle(ctx, call) } func (u userHandlerWithSkip) Register(h Handler, methodName string) { u.localHandler.Register(h, methodName) } ================================================ FILE: handlers_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 

package tchannel

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"golang.org/x/net/context"
)

// dummyHandler is a no-op Handler used to populate a handlerMap in tests.
type dummyHandler struct{}

func (dummyHandler) Handle(ctx context.Context, call *InboundCall) {}

// TestHandlers verifies that handlerMap.find returns nil for unregistered
// methods and the registered handler once Register has been called.
func TestHandlers(t *testing.T) {
	const (
		m1 = "m1"
		m2 = "m2"
	)
	var (
		hmap = &handlerMap{}

		h1  = &dummyHandler{}
		h2  = &dummyHandler{}
		m1b = []byte(m1)
		m2b = []byte(m2)
	)

	assert.Nil(t, hmap.find(m1b))
	assert.Nil(t, hmap.find(m2b))

	hmap.Register(h1, m1)
	assert.Equal(t, h1, hmap.find(m1b))
	assert.Nil(t, hmap.find(m2b))

	hmap.Register(h2, m2)
	assert.Equal(t, h1, hmap.find(m1b))
	assert.Equal(t, h2, hmap.find(m2b))
}

================================================
FILE: handlers_with_skip_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel_test

import (
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/uber/tchannel-go"
	. "github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/raw"
	"github.com/uber/tchannel-go/testutils"
	"go.uber.org/atomic"
	"golang.org/x/net/context"
)

// procedure builds the "service::method" name used both as the call's method
// and as the SkipHandlerMethods entry.
func procedure(svc, method string) string {
	return fmt.Sprintf("%s::%s", svc, method)
}

// TestUserHandlerWithSkip checks that calls to a regular method reach the
// user-provided Handler, while calls to a method listed in
// SkipHandlerMethods reach the handler registered on the channel.
func TestUserHandlerWithSkip(t *testing.T) {
	const (
		svc                  = "svc"
		userHandleMethod     = "method"
		userHandleSkipMethod = "skipMethod"
		handleRuns           = 3
		handleSkipRuns       = 5
	)

	userCounter, channelCounter := &recordHandler{}, &recordHandler{}

	opts := testutils.NewOpts().NoRelay()
	opts.ServiceName = svc
	opts.ChannelOptions = ChannelOptions{
		Handler:            userCounter,
		SkipHandlerMethods: []string{procedure(svc, userHandleSkipMethod)},
	}

	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// channel should be able to handle user ignored methods
		ts.Register(channelCounter, procedure(svc, userHandleSkipMethod))

		client := ts.NewClient(nil)

		// call issues one request. It is a helper so that each context is
		// cancelled as soon as its call returns, instead of piling up
		// deferred cancels until the whole test body finishes.
		call := func(method string) {
			ctx, cancel := tchannel.NewContext(testutils.Timeout(300 * time.Millisecond))
			defer cancel()
			// The result is deliberately not checked: the test only counts
			// handler invocations.
			// NOTE(review): recordHandler never writes a response, so the
			// call presumably returns an error — confirm before asserting on it.
			raw.Call(ctx, client, ts.HostPort(), svc, procedure(svc, method), nil, nil)
		}

		for i := 0; i < handleRuns; i++ {
			call(userHandleMethod)
		}
		assert.Equal(t, uint32(handleRuns), userCounter.c.Load(), "user provided handler not invoked correct amount of times")

		for i := 0; i < handleSkipRuns; i++ {
			call(userHandleSkipMethod)
		}
		assert.Equal(t, uint32(handleSkipRuns), channelCounter.c.Load(), "channel handler not invoked correct amount of times")
	})
}

// TestUserHandlerWithSkipInvalidInput checks that NewChannel rejects a
// SkipHandlerMethods entry that is not in service::Method form.
func TestUserHandlerWithSkipInvalidInput(t *testing.T) {
	opts := &ChannelOptions{
		Handler:            &recordHandler{},
		SkipHandlerMethods: []string{"notDelimitedByDoubleColons"},
	}

	_, err := NewChannel("svc", opts)
	assert.EqualError(t, err, `each "SkipHandlerMethods" value should be of service::Method format but got "notDelimitedByDoubleColons"`)
}

// recordHandler is a Handler that only counts how many times it was invoked.
type recordHandler struct{ c atomic.Uint32 }

func (r *recordHandler) Handle(ctx context.Context, call *InboundCall) {
	r.c.Inc()
}

================================================
FILE: health.go
================================================
// Copyright (c) 2017 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"sync"
	"time"

	"golang.org/x/net/context"
)

const (
	_defaultHealthCheckTimeout         = time.Second
	_defaultHealthCheckFailuresToClose = 5

	// _healthHistorySize is the number of most recent health check results
	// retained by healthHistory.
	_healthHistorySize = 256
)

// HealthCheckOptions are the parameters to configure active TChannel health
// checks. These are not intended to check application level health, but
// TCP connection health (similar to TCP keep-alives). The health checks use
// TChannel ping messages.
type HealthCheckOptions struct {
	// The period between health checks. If this is zero, active health checks
	// are disabled.
	Interval time.Duration

	// The timeout to use for a health check.
	// If no value is specified, it defaults to time.Second.
	Timeout time.Duration

	// FailuresToClose is the number of consecutive health check failures that
	// will cause this connection to be closed.
	// If no value is specified, it defaults to 5.
	FailuresToClose int
}

// healthHistory is a lock-protected, fixed-size circular buffer holding the
// outcome of the most recent _healthHistorySize health checks.
type healthHistory struct {
	sync.RWMutex

	states []bool

	// insertAt is the index in states that the next result is written to.
	insertAt int
	// total is the number of results recorded so far.
	total int
}

// newHealthHistory returns an empty history with its buffer pre-allocated.
func newHealthHistory() *healthHistory {
	return &healthHistory{
		states: make([]bool, _healthHistorySize),
	}
}

// add records the result of one health check, overwriting the oldest entry
// once the buffer is full.
func (hh *healthHistory) add(b bool) {
	hh.Lock()
	defer hh.Unlock()

	hh.states[hh.insertAt] = b
	hh.insertAt = (hh.insertAt + 1) % _healthHistorySize
	hh.total++
}

// asBools returns a copy of the recorded results ordered from oldest to
// newest. It returns nil when nothing has been recorded yet.
func (hh *healthHistory) asBools() []bool {
	hh.RLock()
	defer hh.RUnlock()

	// The buffer has not wrapped yet, so the results are states[:total].
	if hh.total < _healthHistorySize {
		return append([]bool(nil), hh.states[:hh.total]...)
	}

	// Once wrapped, insertAt points at the oldest entry: copy from there to
	// the end, followed by the start of the buffer up to insertAt.
	states := hh.states
	copyStates := make([]bool, 0, _healthHistorySize)
	copyStates = append(copyStates, states[hh.insertAt:]...)
	copyStates = append(copyStates, states[:hh.insertAt]...)
	return copyStates
}

// enabled reports whether active health checks are configured, i.e. whether
// Interval is positive.
func (hco HealthCheckOptions) enabled() bool {
	return hco.Interval > 0
}

// withDefaults returns a copy of the options with zero-valued Timeout and
// FailuresToClose replaced by their defaults. Interval is left untouched.
func (hco HealthCheckOptions) withDefaults() HealthCheckOptions {
	if hco.Timeout == 0 {
		hco.Timeout = _defaultHealthCheckTimeout
	}
	if hco.FailuresToClose == 0 {
		hco.FailuresToClose = _defaultHealthCheckFailuresToClose
	}
	return hco
}

// healthCheck will do periodic pings on the connection to check the state of the connection.
// We accept connID on the stack so can more easily debug panics or leaked goroutines.
func (c *Connection) healthCheck(connID uint32) {
	// Closing healthCheckDone on exit unblocks stopHealthCheck, which waits
	// on this channel.
	defer close(c.healthCheckDone)

	opts := c.opts.HealthChecks

	ticker := c.timeTicker(opts.Interval)
	defer ticker.Stop()

	consecutiveFailures := 0
	for {
		// Wait for the next tick, or exit once health checks are cancelled.
		select {
		case <-ticker.C:
		case <-c.healthCheckCtx.Done():
			return
		}

		// The ping context derives from healthCheckCtx, so it is bounded by
		// opts.Timeout and is also cancelled when health checks are stopped.
		ctx, cancel := context.WithTimeout(c.healthCheckCtx, opts.Timeout)
		err := c.ping(ctx)
		// Release the timeout context right away rather than deferring, since
		// this runs once per loop iteration.
		cancel()
		c.healthCheckHistory.add(err == nil)
		if err == nil {
			if c.log.Enabled(LogLevelDebug) {
				c.log.Debug("Performed successful active health check.")
			}
			// Only consecutive failures count towards closing the connection.
			consecutiveFailures = 0
			continue
		}

		// If the health check failed because the connection closed or health
		// checks were stopped, we don't need to log or close the connection.
		if GetSystemErrorCode(err) == ErrCodeCancelled || err == ErrInvalidConnectionState {
			c.log.WithFields(ErrField(err)).Debug("Health checker stopped.")
			return
		}

		consecutiveFailures++
		c.log.WithFields(LogFields{
			{"consecutiveFailures", consecutiveFailures},
			ErrField(err),
			{"failuresToClose", opts.FailuresToClose},
		}...).Warn("Failed active health check.")

		// Too many failures in a row: close the connection and stop checking.
		if consecutiveFailures >= opts.FailuresToClose {
			c.close(LogFields{
				{"reason", "health check failure"},
				ErrField(err),
			}...)
			return
		}
	}
}

// stopHealthCheck stops the health check goroutine and blocks until it has
// exited. It returns immediately when health checks are not enabled or when
// the health check context already reports an error.
func (c *Connection) stopHealthCheck() {
	// Health checks are not enabled.
	if c.healthCheckDone == nil {
		return
	}

	// Best effort check to see if health checks were stopped.
	if c.healthCheckCtx.Err() != nil {
		return
	}
	c.log.Debug("Stopping health checks.")
	// NOTE(review): healthCheckQuit presumably cancels healthCheckCtx, which
	// makes healthCheck return — confirm where it is assigned.
	c.healthCheckQuit()
	// Wait for healthCheck to close healthCheckDone on its way out.
	<-c.healthCheckDone
}

================================================
FILE: health_ext_test.go
================================================
// Copyright (c) 2017 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "strings" "testing" "time" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestHealthCheckStopBeforeStart(t *testing.T) { opts := testutils.NewOpts().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { var pingCount int frameRelay, cancel := testutils.FrameRelay(t, ts.HostPort(), func(outgoing bool, f *Frame) *Frame { if strings.Contains(f.Header.String(), "PingRes") { pingCount++ } return f }) defer cancel() ft := testutils.NewFakeTicker() opts := testutils.NewOpts(). SetTimeTicker(ft.New). 
SetHealthChecks(HealthCheckOptions{Interval: time.Second}) client := ts.NewClient(opts) ctx, cancel := NewContext(time.Second) defer cancel() conn, err := client.RootPeers().GetOrAdd(frameRelay).GetConnection(ctx) require.NoError(t, err, "Failed to get connection") conn.StopHealthCheck() // Should be no ping messages sent. for i := 0; i < 10; i++ { ft.TryTick() } assert.Equal(t, 0, pingCount, "No pings when health check is stopped") }) } func TestHealthCheckStopNoError(t *testing.T) { opts := testutils.NewOpts().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { var pingCount int frameRelay, cancel := testutils.FrameRelay(t, ts.HostPort(), func(outgoing bool, f *Frame) *Frame { if strings.Contains(f.Header.String(), "PingRes") { pingCount++ } return f }) defer cancel() ft := testutils.NewFakeTicker() opts := testutils.NewOpts(). SetTimeTicker(ft.New). SetHealthChecks(HealthCheckOptions{Interval: time.Second}). AddLogFilter("Unexpected ping response.", 1) client := ts.NewClient(opts) ctx, cancel := NewContext(time.Second) defer cancel() conn, err := client.RootPeers().GetOrAdd(frameRelay).GetConnection(ctx) require.NoError(t, err, "Failed to get connection") for i := 0; i < 10; i++ { ft.Tick() waitForNHealthChecks(t, conn, i+1) } conn.StopHealthCheck() // We stop the health check, so the ticks channel is no longer read, so // we can't use the synchronous tick here. 
for i := 0; i < 10; i++ { ft.TryTick() } assert.Equal(t, 10, pingCount, "Pings should stop after health check is stopped") }) } func TestHealthCheckIntegration(t *testing.T) { tests := []struct { msg string disable bool failuresToClose int pingResponses []bool wantActive bool wantHealthCheckLogs int }{ { msg: "no failures with failuresToClose=0", failuresToClose: 1, pingResponses: []bool{true, true, true, true}, wantActive: true, }, { msg: "single failure with failuresToClose=1", failuresToClose: 1, pingResponses: []bool{true, false}, wantActive: false, wantHealthCheckLogs: 1, }, { msg: "single failure with failuresToClose=2", failuresToClose: 2, pingResponses: []bool{true, false, true, false, true}, wantActive: true, wantHealthCheckLogs: 2, }, { msg: "up to 2 consecutive failures with failuresToClose=3", failuresToClose: 3, pingResponses: []bool{true, false, true, false, true, false, false, true, false, false, true}, wantActive: true, wantHealthCheckLogs: 6, }, { msg: "3 consecutive failures with failuresToClose=3", failuresToClose: 3, pingResponses: []bool{true, false, true, false, true, false, false, true, false, false, false}, wantActive: false, wantHealthCheckLogs: 7, }, } errFrame := getErrorFrame(t) for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { opts := testutils.NewOpts().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { var pingCount int frameRelay, cancel := testutils.FrameRelay(t, ts.HostPort(), func(outgoing bool, f *Frame) *Frame { if strings.Contains(f.Header.String(), "PingRes") { success := tt.pingResponses[pingCount] pingCount++ if !success { errFrame.Header.ID = f.Header.ID f = errFrame } } return f }) defer cancel() ft := testutils.NewFakeTicker() opts := testutils.NewOpts(). SetTimeTicker(ft.New). SetHealthChecks(HealthCheckOptions{Interval: time.Second, FailuresToClose: tt.failuresToClose}). AddLogFilter("Failed active health check.", uint(tt.wantHealthCheckLogs)). 
AddLogFilter("Unexpected ping response.", 1) client := ts.NewClient(opts) ctx, cancel := NewContext(time.Second) defer cancel() conn, err := client.RootPeers().GetOrAdd(frameRelay).GetConnection(ctx) require.NoError(t, err, "Failed to get connection") for i := 0; i < len(tt.pingResponses); i++ { ft.TryTick() waitForNHealthChecks(t, conn, i+1) assert.Equal(t, tt.pingResponses[:i+1], introspectConn(conn).HealthChecks, "Unexpectd health check history") } // Once the health check is done, we trigger a Close, it's possible we are still // waiting for the connection to close. if tt.wantActive == false { testutils.WaitFor(time.Second, func() bool { return !conn.IsActive() }) } assert.Equal(t, tt.wantActive, conn.IsActive(), "Connection active mismatch") }) }) } } func waitForNHealthChecks(t testing.TB, conn *Connection, n int) { require.True(t, testutils.WaitFor(time.Second, func() bool { return len(introspectConn(conn).HealthChecks) >= n }), "Failed while waiting for %v health checks", n) } func introspectConn(c *Connection) ConnectionRuntimeState { return c.IntrospectState(&IntrospectionOptions{}) } ================================================ FILE: health_test.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "math/rand" "testing" "time" "github.com/stretchr/testify/assert" ) func TestHealthCheckEnabled(t *testing.T) { hc := HealthCheckOptions{} assert.False(t, hc.enabled(), "Default struct should not have health checks enabled") hc.Interval = time.Second assert.True(t, hc.enabled(), "Setting interval should enable health checks") } func TestHealthCheckOptionsDefaults(t *testing.T) { tests := []struct { opts HealthCheckOptions want HealthCheckOptions }{ { opts: HealthCheckOptions{}, want: HealthCheckOptions{Timeout: _defaultHealthCheckTimeout, FailuresToClose: _defaultHealthCheckFailuresToClose}, }, { opts: HealthCheckOptions{Timeout: 2 * time.Second}, want: HealthCheckOptions{Timeout: 2 * time.Second, FailuresToClose: _defaultHealthCheckFailuresToClose}, }, { opts: HealthCheckOptions{FailuresToClose: 3}, want: HealthCheckOptions{Timeout: _defaultHealthCheckTimeout, FailuresToClose: 3}, }, { opts: HealthCheckOptions{Timeout: 2 * time.Second, FailuresToClose: 3}, want: HealthCheckOptions{Timeout: 2 * time.Second, FailuresToClose: 3}, }, } for _, tt := range tests { got := tt.opts.withDefaults() assert.Equal(t, tt.want, got, "Unexpected defaults for %+v", tt.opts) } } func TestHealthHistory(t *testing.T) { hh := newHealthHistory() var want []bool for i := 0; i < 1000; i++ { assert.Equal(t, want, hh.asBools()) b := rand.Intn(3) > 0 hh.add(b) want = append(want, b) if len(want) > _healthHistorySize { want = want[1:] } } } ================================================ 
FILE: http/buf.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package http import ( "net/http" "github.com/uber/tchannel-go/typed" ) func writeHeaders(wb *typed.WriteBuffer, form http.Header) { numHeadersDeferred := wb.DeferUint16() numHeaders := uint16(0) for k, values := range form { for _, v := range values { wb.WriteLen16String(k) wb.WriteLen16String(v) numHeaders++ } } numHeadersDeferred.Update(numHeaders) } func readHeaders(rb *typed.ReadBuffer, form http.Header) { numHeaders := rb.ReadUint16() for i := 0; i < int(numHeaders); i++ { k := rb.ReadLen16String() v := rb.ReadLen16String() form[k] = append(form[k], v) } } func readVarintString(rb *typed.ReadBuffer) string { length := rb.ReadUvarint() return rb.ReadString(int(length)) } func writeVarintString(wb *typed.WriteBuffer, s string) { wb.WriteUvarint(uint64(len(s))) wb.WriteString(s) } ================================================ FILE: http/buf_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package http import ( "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/typed" ) func TestHeaders(t *testing.T) { tests := []http.Header{ {}, { "K1": []string{"K1V1", "K1V2", "K1V3"}, "K2": []string{"K2V2", "K2V2"}, }, } for _, tt := range tests { buf := make([]byte, 1000) wb := typed.NewWriteBuffer(buf) writeHeaders(wb, tt) newHeaders := make(http.Header) rb := typed.NewReadBuffer(buf) readHeaders(rb, newHeaders) assert.Equal(t, tt, newHeaders, "Headers mismatch") } } func TestVarintString(t *testing.T) { tests := []string{ "", "short string", testutils.RandString(1000), } for _, tt := range tests { buf := make([]byte, 2000) wb := typed.NewWriteBuffer(buf) writeVarintString(wb, tt) rb := typed.NewReadBuffer(buf) got := readVarintString(rb) assert.Equal(t, tt, got, "Varint string mismatch") } } ================================================ FILE: http/http_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// dumpHandler is an http.HandlerFunc that echoes a textual summary of the
// request it received: the method and URL, a "Headers" line (which prints
// r.Form, not r.Header), and the request body. It also sets three fixed
// response headers before writing the summary.
func dumpHandler(w http.ResponseWriter, r *http.Request) {
	// Pin the URL's scheme and host to fixed values; presumably this keeps the
	// dump independent of the address the request was sent to.
	r.URL.Scheme, r.URL.Host = "http", "test.local"

	// httputil.DumpRequestOut is not suitable here because it prints the
	// chunked encoding, whereas only the data a reader would see matters.
	var summary bytes.Buffer
	fmt.Fprintf(&summary, "%s%s\nHeaders: %v\n", r.Method, r.URL.String(), r.Form)
	summary.WriteString("Body: ")
	io.Copy(&summary, r.Body)
	summary.WriteByte('\n')

	respHeaders := w.Header()
	for _, kv := range [][2]string{
		{"My-Header-1", "V1"},
		{"My-Header-1", "V2"},
		{"My-Header-2", "V3"},
	} {
		respHeaders.Add(kv[0], kv[1])
	}

	io.WriteString(w, "Dumped request:\n")
	w.Write(summary.Bytes())
}
writer, finish := ResponseWriter(call.Response()) mux.ServeHTTP(writer, req) finish() } ch.Register(tchannel.HandlerFunc(handler), "http") return ch.PeerInfo().HostPort, func() { ch.Close() } } func setupProxy(t *testing.T, tchanAddr string) (string, func()) { mux := http.NewServeMux() mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // You get /proxy/host:port/rest/of/the/path parts := strings.SplitN(r.URL.Path, "/", 4) r.URL.Host = parts[2] r.URL.Scheme = "http" r.URL.Path = parts[3] ch := testutils.NewClient(t, nil) ctx, cancel := tchannel.NewContext(time.Second) defer cancel() call, err := ch.BeginCall(ctx, tchanAddr, "test", "http", nil) require.NoError(t, err, "BeginCall failed") require.NoError(t, WriteRequest(call, r), "WriteRequest failed") resp, err := ReadResponse(call.Response()) require.NoError(t, err, "Read response failed") for k, vs := range resp.Header { for _, v := range vs { w.Header().Add(k, v) } } w.WriteHeader(resp.StatusCode) _, err = io.Copy(w, resp.Body) assert.NoError(t, err, "io.Copy failed") err = resp.Body.Close() assert.NoError(t, err, "Close Response Body failed") })) return setupHTTP(t, mux) } // setupServer sets up a HTTP handler and a TChannel handler . 
func setupServer(t *testing.T) (string, string, func()) { mux := http.NewServeMux() mux.Handle("/", http.HandlerFunc(dumpHandler)) httpAddr, httpClose := setupHTTP(t, mux) tchanAddr, tchanClose := setupTChan(t, mux) close := func() { httpClose() tchanClose() } return httpAddr, tchanAddr, close } func makeHTTPCall(t *testing.T, req *http.Request) *http.Response { resp, err := http.DefaultClient.Do(req) require.NoError(t, err, "HTTP request failed") return resp } func makeTChanCall(t *testing.T, tchanAddr string, req *http.Request) *http.Response { ch := testutils.NewClient(t, nil) ctx, cancel := tchannel.NewContext(time.Second) defer cancel() call, err := ch.BeginCall(ctx, tchanAddr, "test", "http", nil) require.NoError(t, err, "BeginCall failed") require.NoError(t, WriteRequest(call, req), "WriteRequest failed") resp, err := ReadResponse(call.Response()) require.NoError(t, err, "Read response failed") return resp } func compareResponseBasic(t *testing.T, testName string, resp1, resp2 *http.Response) { resp1Body, err := ioutil.ReadAll(resp1.Body) require.NoError(t, err, "Read response failed") resp2Body, err := ioutil.ReadAll(resp2.Body) require.NoError(t, err, "Read response failed") assert.Equal(t, resp1.Status, resp2.Status, "%v: Response status mismatch", testName) assert.Equal(t, resp1.StatusCode, resp2.StatusCode, "%v: Response status code mismatch", testName) assert.Equal(t, string(resp1Body), string(resp2Body), "%v: Response body mismatch", testName) } func compareResponses(t *testing.T, testName string, resp1, resp2 *http.Response) { resp1Bs, err := httputil.DumpResponse(resp1, true) require.NoError(t, err, "Dump response") resp2Bs, err := httputil.DumpResponse(resp2, true) require.NoError(t, err, "Dump response") assert.Equal(t, string(resp1Bs), string(resp2Bs), "%v: Response mismatch", testName) } type requestTest struct { name string f func(string) *http.Request } func getRequestTests(t *testing.T) []requestTest { randBytes := testutils.RandBytes(40000) 
return []requestTest{ { name: "get simple", f: func(httpAddr string) *http.Request { req, err := http.NewRequest("GET", fmt.Sprintf("http://%v/this/is/my?req=1&v=2&v&a&a", httpAddr), nil) require.NoError(t, err, "NewRequest failed") return req }, }, { name: "post simple", f: func(httpAddr string) *http.Request { body := strings.NewReader("This is a simple POST body") req, err := http.NewRequest("POST", fmt.Sprintf("http://%v/post/path?v=1&b=3", httpAddr), body) require.NoError(t, err, "NewRequest failed") return req }, }, { name: "post random bytes", f: func(httpAddr string) *http.Request { body := bytes.NewReader(randBytes) req, err := http.NewRequest("POST", fmt.Sprintf("http://%v/post/path?v=1&b=3", httpAddr), body) require.NoError(t, err, "NewRequest failed") return req }, }, } } func TestDirectRequests(t *testing.T) { httpAddr, tchanAddr, finish := setupServer(t) defer finish() tests := getRequestTests(t) for _, tt := range tests { resp1 := makeHTTPCall(t, tt.f(httpAddr)) resp2 := makeTChanCall(t, tchanAddr, tt.f(httpAddr)) compareResponseBasic(t, tt.name, resp1, resp2) } } func TestProxyRequests(t *testing.T) { httpAddr, tchanAddr, finish := setupServer(t) defer finish() proxyAddr, finish := setupProxy(t, tchanAddr) defer finish() tests := getRequestTests(t) for _, tt := range tests { resp1 := makeHTTPCall(t, tt.f(httpAddr)) resp2 := makeHTTPCall(t, tt.f(proxyAddr+"/proxy/"+httpAddr)) // Delete the Date header since the calls are made at different times. resp1.Header.Del("Date") resp2.Header.Del("Date") compareResponses(t, tt.name, resp1, resp2) } } ================================================ FILE: http/request.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package http import ( "io" "net/http" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/typed" ) // WriteRequest writes a http.Request to the given writers. func WriteRequest(call tchannel.ArgWritable, req *http.Request) error { // TODO(prashant): Allow creating write buffers that let you grow the buffer underneath. wb := typed.NewWriteBufferWithSize(10000) wb.WriteLen8String(req.Method) writeVarintString(wb, req.URL.String()) writeHeaders(wb, req.Header) arg2Writer, err := call.Arg2Writer() if err != nil { return err } if _, err := wb.FlushTo(arg2Writer); err != nil { return err } if err := arg2Writer.Close(); err != nil { return err } arg3Writer, err := call.Arg3Writer() if err != nil { return err } if req.Body != nil { if _, err = io.Copy(arg3Writer, req.Body); err != nil { return err } } return arg3Writer.Close() } // ReadRequest reads a http.Request from the given readers. 
func ReadRequest(call tchannel.ArgReadable) (*http.Request, error) { var arg2 []byte if err := tchannel.NewArgReader(call.Arg2Reader()).Read(&arg2); err != nil { return nil, err } rb := typed.NewReadBuffer(arg2) method := rb.ReadLen8String() url := readVarintString(rb) r, err := http.NewRequest(method, url, nil) if err != nil { return nil, err } readHeaders(rb, r.Header) if err := rb.Err(); err != nil { return nil, err } r.Body, err = call.Arg3Reader() return r, err } ================================================ FILE: http/response.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package http import ( "fmt" "io" "net/http" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/typed" ) // ReadResponse reads a http.Response from the given readers. 
func ReadResponse(call tchannel.ArgReadable) (*http.Response, error) { var arg2 []byte if err := tchannel.NewArgReader(call.Arg2Reader()).Read(&arg2); err != nil { return nil, err } rb := typed.NewReadBuffer(arg2) statusCode := rb.ReadUint16() message := readVarintString(rb) response := &http.Response{ StatusCode: int(statusCode), Status: fmt.Sprintf("%v %v", statusCode, message), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), } readHeaders(rb, response.Header) if err := rb.Err(); err != nil { return nil, err } arg3Reader, err := call.Arg3Reader() if err != nil { return nil, err } response.Body = arg3Reader return response, nil } type tchanResponseWriter struct { headers http.Header statusCode int response tchannel.ArgWritable arg3Writer io.WriteCloser err error } func newTChanResponseWriter(response tchannel.ArgWritable) *tchanResponseWriter { return &tchanResponseWriter{ headers: make(http.Header), statusCode: http.StatusOK, response: response, } } func (w *tchanResponseWriter) Header() http.Header { return w.headers } func (w *tchanResponseWriter) WriteHeader(statusCode int) { w.statusCode = statusCode } // writeHeaders writes out the HTTP headers as arg2, and creates the arg3 writer. func (w *tchanResponseWriter) writeHeaders() { // TODO(prashant): Allow creating write buffers that let you grow the buffer underneath. 
wb := typed.NewWriteBufferWithSize(10000) wb.WriteUint16(uint16(w.statusCode)) writeVarintString(wb, http.StatusText(w.statusCode)) writeHeaders(wb, w.headers) arg2Writer, err := w.response.Arg2Writer() if err != nil { w.err = err return } if _, w.err = wb.FlushTo(arg2Writer); w.err != nil { return } if w.err = arg2Writer.Close(); w.err != nil { return } w.arg3Writer, w.err = w.response.Arg3Writer() } func (w *tchanResponseWriter) Write(bs []byte) (int, error) { if w.err != nil { return 0, w.err } if w.arg3Writer == nil { w.writeHeaders() } if w.err != nil { return 0, w.err } return w.arg3Writer.Write(bs) } func (w *tchanResponseWriter) finish() error { if w.arg3Writer == nil || w.err != nil { return w.err } return w.arg3Writer.Close() } // ResponseWriter returns a http.ResponseWriter that will write to an underlying writer. // It also returns a function that should be called once the handler has completed. func ResponseWriter(response tchannel.ArgWritable) (http.ResponseWriter, func() error) { responseWriter := newTChanResponseWriter(response) return responseWriter, responseWriter.finish } ================================================ FILE: hyperbahn/advertise.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// fuzzInterval returns a fuzzed version of the interval based on FullJitter as described here:
// http://www.awsarchitectureblog.com/2015/03/backoff.html
//
// The result is a random duration in [0, interval). A zero or negative
// interval yields 0 instead of panicking, because rand.Int63n panics when its
// argument is not positive.
func fuzzInterval(interval time.Duration) time.Duration {
	if interval <= 0 {
		return 0
	}
	return time.Duration(rand.Int63n(int64(interval)))
}
// logFailedRegistrationRetry records that an advertise attempt failed and
// will be retried. The message is logged at warn level once
// consecutiveFailures exceeds maxAdvertiseFailures, and at info level
// otherwise.
func (c *Client) logFailedRegistrationRetry(errLogger tchannel.Logger, consecutiveFailures uint) {
	const retryMsg = "Hyperbahn client registration failed, will retry."

	if consecutiveFailures > maxAdvertiseFailures {
		errLogger.Warn(retryMsg)
		return
	}
	errLogger.Info(retryMsg)
}
for i := 1; i <= maxAdvertiseFailures*2; i++ { r.sleepBlock <- struct{}{} r.setAdvertiseFailure() <-r.reqCh sleptFor := <-r.sleepArgs // Make sure that we cap backoff at some reasonable duration, even // after many retries. if i <= maxAdvertiseFailures { checkRetryInterval(t, sleptFor, i) } else { checkRetryInterval(t, sleptFor, maxAdvertiseFailures) } } r.sleepClose() // Wait for the handler to be called and the mock expectation to be recorded. <-doneTesting }) } func checkAdvertiseInterval(t *testing.T, sleptFor time.Duration) { assert.True(t, sleptFor >= advertiseInterval, "advertise interval should be > advertiseInterval") assert.True(t, sleptFor < advertiseInterval+advertiseFuzzInterval, "advertise interval should be < advertiseInterval + advertiseFuzzInterval") } func checkRetryInterval(t *testing.T, sleptFor time.Duration, retryNum int) { maxRetryInterval := advertiseRetryInterval * time.Duration(1<= Event(len(_Event_index)) { return fmt.Sprintf("Event(%d)", i) } return _Event_name[_Event_index[i]:_Event_index[i+1]] } ================================================ FILE: hyperbahn/events.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// Event identifies a notification that Client delivers to its Handler.
type Event int

const (
	// UnknownEvent is the zero value of Event and should never be used.
	UnknownEvent Event = iota
	// SendAdvertise fires when the Hyperbahn client attempts an advertise.
	SendAdvertise
	// Advertised fires when the first advertisement for a service succeeds.
	Advertised
	// Readvertised fires on the periodic advertisements.
	Readvertised
)

//go:generate stringer -type=Event

// Handler receives the events and errors raised by the Hyperbahn client.
type Handler interface {
	// On is invoked each time an event is triggered.
	On(event Event)
	// OnError is invoked each time an error is detected.
	OnError(err error)
}

// nullHandler is a Handler whose methods do nothing. It is the default used
// when nil is passed, so that handlers can always be called.
type nullHandler struct{}

// On ignores the event.
func (nullHandler) On(Event) {}

// OnError ignores the error.
func (nullHandler) OnError(error) {}
var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal func init() { } ================================================ FILE: hyperbahn/gen-go/hyperbahn/hyperbahn.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package hyperbahn import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal type Hyperbahn interface { // Parameters: // - Query Discover(query *DiscoveryQuery) (r *DiscoveryResult_, err error) } type HyperbahnClient struct { Transport thrift.TTransport ProtocolFactory thrift.TProtocolFactory InputProtocol thrift.TProtocol OutputProtocol thrift.TProtocol SeqId int32 } func NewHyperbahnClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *HyperbahnClient { return &HyperbahnClient{Transport: t, ProtocolFactory: f, InputProtocol: f.GetProtocol(t), OutputProtocol: f.GetProtocol(t), SeqId: 0, } } func NewHyperbahnClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *HyperbahnClient { return &HyperbahnClient{Transport: t, ProtocolFactory: nil, InputProtocol: iprot, OutputProtocol: oprot, SeqId: 0, } } // Parameters: // - Query func (p *HyperbahnClient) Discover(query *DiscoveryQuery) (r *DiscoveryResult_, err error) { if err = p.sendDiscover(query); err != nil { return } return p.recvDiscover() } func (p *HyperbahnClient) sendDiscover(query *DiscoveryQuery) (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("discover", thrift.CALL, p.SeqId); err != nil { return } args := HyperbahnDiscoverArgs{ Query: query, } if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil 
{ return } return oprot.Flush() } func (p *HyperbahnClient) recvDiscover() (value *DiscoveryResult_, err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "discover" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "discover failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "discover failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error1 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error2 error error2, err = error1.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error2 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "discover failed: invalid message type") return } result := HyperbahnDiscoverResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } if result.NoPeersAvailable != nil { err = result.NoPeersAvailable return } else if result.InvalidServiceName != nil { err = result.InvalidServiceName return } value = result.GetSuccess() return } type HyperbahnProcessor struct { processorMap map[string]thrift.TProcessorFunction handler Hyperbahn } func (p *HyperbahnProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) { p.processorMap[key] = processor } func (p *HyperbahnProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) { processor, ok = p.processorMap[key] return processor, ok } func (p *HyperbahnProcessor) ProcessorMap() map[string]thrift.TProcessorFunction { return p.processorMap } func NewHyperbahnProcessor(handler Hyperbahn) *HyperbahnProcessor { self3 := &HyperbahnProcessor{handler: handler, 
processorMap: make(map[string]thrift.TProcessorFunction)} self3.processorMap["discover"] = &hyperbahnProcessorDiscover{handler: handler} return self3 } func (p *HyperbahnProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { name, _, seqId, err := iprot.ReadMessageBegin() if err != nil { return false, err } if processor, ok := p.GetProcessorFunction(name); ok { return processor.Process(seqId, iprot, oprot) } iprot.Skip(thrift.STRUCT) iprot.ReadMessageEnd() x4 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) x4.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, x4 } type hyperbahnProcessorDiscover struct { handler Hyperbahn } func (p *hyperbahnProcessorDiscover) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := HyperbahnDiscoverArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("discover", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := HyperbahnDiscoverResult{} var retval *DiscoveryResult_ var err2 error if retval, err2 = p.handler.Discover(args.Query); err2 != nil { switch v := err2.(type) { case *NoPeersAvailable: result.NoPeersAvailable = v case *InvalidServiceName: result.InvalidServiceName = v default: x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing discover: "+err2.Error()) oprot.WriteMessageBegin("discover", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } } else { result.Success = retval } if err2 = oprot.WriteMessageBegin("discover", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); 
err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } // HELPER FUNCTIONS AND STRUCTURES // Attributes: // - Query type HyperbahnDiscoverArgs struct { Query *DiscoveryQuery `thrift:"query,1,required" db:"query" json:"query"` } func NewHyperbahnDiscoverArgs() *HyperbahnDiscoverArgs { return &HyperbahnDiscoverArgs{} } var HyperbahnDiscoverArgs_Query_DEFAULT *DiscoveryQuery func (p *HyperbahnDiscoverArgs) GetQuery() *DiscoveryQuery { if !p.IsSetQuery() { return HyperbahnDiscoverArgs_Query_DEFAULT } return p.Query } func (p *HyperbahnDiscoverArgs) IsSetQuery() bool { return p.Query != nil } func (p *HyperbahnDiscoverArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetQuery bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetQuery = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetQuery { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Query is not set")) } return nil } func (p *HyperbahnDiscoverArgs) ReadField1(iprot thrift.TProtocol) error { p.Query = &DiscoveryQuery{} if err := p.Query.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Query), err) } return nil } func (p *HyperbahnDiscoverArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("discover_args"); err != nil { return 
thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *HyperbahnDiscoverArgs) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("query", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:query: ", p), err) } if err := p.Query.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Query), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:query: ", p), err) } return err } func (p *HyperbahnDiscoverArgs) String() string { if p == nil { return "" } return fmt.Sprintf("HyperbahnDiscoverArgs(%+v)", *p) } // Attributes: // - Success // - NoPeersAvailable // - InvalidServiceName type HyperbahnDiscoverResult struct { Success *DiscoveryResult_ `thrift:"success,0" db:"success" json:"success,omitempty"` NoPeersAvailable *NoPeersAvailable `thrift:"noPeersAvailable,1" db:"noPeersAvailable" json:"noPeersAvailable,omitempty"` InvalidServiceName *InvalidServiceName `thrift:"invalidServiceName,2" db:"invalidServiceName" json:"invalidServiceName,omitempty"` } func NewHyperbahnDiscoverResult() *HyperbahnDiscoverResult { return &HyperbahnDiscoverResult{} } var HyperbahnDiscoverResult_Success_DEFAULT *DiscoveryResult_ func (p *HyperbahnDiscoverResult) GetSuccess() *DiscoveryResult_ { if !p.IsSetSuccess() { return HyperbahnDiscoverResult_Success_DEFAULT } return p.Success } var HyperbahnDiscoverResult_NoPeersAvailable_DEFAULT *NoPeersAvailable func (p *HyperbahnDiscoverResult) GetNoPeersAvailable() *NoPeersAvailable { if !p.IsSetNoPeersAvailable() { return 
HyperbahnDiscoverResult_NoPeersAvailable_DEFAULT } return p.NoPeersAvailable } var HyperbahnDiscoverResult_InvalidServiceName_DEFAULT *InvalidServiceName func (p *HyperbahnDiscoverResult) GetInvalidServiceName() *InvalidServiceName { if !p.IsSetInvalidServiceName() { return HyperbahnDiscoverResult_InvalidServiceName_DEFAULT } return p.InvalidServiceName } func (p *HyperbahnDiscoverResult) IsSetSuccess() bool { return p.Success != nil } func (p *HyperbahnDiscoverResult) IsSetNoPeersAvailable() bool { return p.NoPeersAvailable != nil } func (p *HyperbahnDiscoverResult) IsSetInvalidServiceName() bool { return p.InvalidServiceName != nil } func (p *HyperbahnDiscoverResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 0: if err := p.ReadField0(iprot); err != nil { return err } case 1: if err := p.ReadField1(iprot); err != nil { return err } case 2: if err := p.ReadField2(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *HyperbahnDiscoverResult) ReadField0(iprot thrift.TProtocol) error { p.Success = &DiscoveryResult_{} if err := p.Success.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Success), err) } return nil } func (p *HyperbahnDiscoverResult) ReadField1(iprot thrift.TProtocol) error { p.NoPeersAvailable = &NoPeersAvailable{} if err := p.NoPeersAvailable.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error 
reading struct: ", p.NoPeersAvailable), err) } return nil } func (p *HyperbahnDiscoverResult) ReadField2(iprot thrift.TProtocol) error { p.InvalidServiceName = &InvalidServiceName{} if err := p.InvalidServiceName.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.InvalidServiceName), err) } return nil } func (p *HyperbahnDiscoverResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("discover_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField0(oprot); err != nil { return err } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *HyperbahnDiscoverResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err := oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err) } if err := p.Success.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Success), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err) } } return err } func (p *HyperbahnDiscoverResult) writeField1(oprot thrift.TProtocol) (err error) { if p.IsSetNoPeersAvailable() { if err := oprot.WriteFieldBegin("noPeersAvailable", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:noPeersAvailable: ", p), err) } if err := p.NoPeersAvailable.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.NoPeersAvailable), err) } if err := 
oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:noPeersAvailable: ", p), err) } } return err } func (p *HyperbahnDiscoverResult) writeField2(oprot thrift.TProtocol) (err error) { if p.IsSetInvalidServiceName() { if err := oprot.WriteFieldBegin("invalidServiceName", thrift.STRUCT, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:invalidServiceName: ", p), err) } if err := p.InvalidServiceName.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.InvalidServiceName), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:invalidServiceName: ", p), err) } } return err } func (p *HyperbahnDiscoverResult) String() string { if p == nil { return "" } return fmt.Sprintf("HyperbahnDiscoverResult(%+v)", *p) } ================================================ FILE: hyperbahn/gen-go/hyperbahn/tchan-hyperbahn.go ================================================ // @generated Code generated by thrift-gen. Do not modify. // Package hyperbahn is generated code used to make or handle TChannel calls using Thrift. package hyperbahn import ( "fmt" athrift "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" "github.com/uber/tchannel-go/thrift" ) // Interfaces for the service and client for the services defined in the IDL. // TChanHyperbahn is the interface that defines the server handler and client interface. type TChanHyperbahn interface { Discover(ctx thrift.Context, query *DiscoveryQuery) (*DiscoveryResult_, error) } // Implementation of a client and service handler. 
type tchanHyperbahnClient struct { thriftService string client thrift.TChanClient } func NewTChanHyperbahnInheritedClient(thriftService string, client thrift.TChanClient) *tchanHyperbahnClient { return &tchanHyperbahnClient{ thriftService, client, } } // NewTChanHyperbahnClient creates a client that can be used to make remote calls. func NewTChanHyperbahnClient(client thrift.TChanClient) TChanHyperbahn { return NewTChanHyperbahnInheritedClient("Hyperbahn", client) } func (c *tchanHyperbahnClient) Discover(ctx thrift.Context, query *DiscoveryQuery) (*DiscoveryResult_, error) { var resp HyperbahnDiscoverResult args := HyperbahnDiscoverArgs{ Query: query, } success, err := c.client.Call(ctx, c.thriftService, "discover", &args, &resp) if err == nil && !success { switch { case resp.NoPeersAvailable != nil: err = resp.NoPeersAvailable case resp.InvalidServiceName != nil: err = resp.InvalidServiceName default: err = fmt.Errorf("received no result or unknown exception for discover") } } return resp.GetSuccess(), err } type tchanHyperbahnServer struct { handler TChanHyperbahn } // NewTChanHyperbahnServer wraps a handler for TChanHyperbahn so it can be // registered with a thrift.Server. 
func NewTChanHyperbahnServer(handler TChanHyperbahn) thrift.TChanServer { return &tchanHyperbahnServer{ handler, } } func (s *tchanHyperbahnServer) Service() string { return "Hyperbahn" } func (s *tchanHyperbahnServer) Methods() []string { return []string{ "discover", } } func (s *tchanHyperbahnServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "discover": return s.handleDiscover(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanHyperbahnServer) handleDiscover(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req HyperbahnDiscoverArgs var res HyperbahnDiscoverResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Discover(ctx, req.Query) if err != nil { switch v := err.(type) { case *NoPeersAvailable: if v == nil { return false, nil, fmt.Errorf("Handler for noPeersAvailable returned non-nil error type *NoPeersAvailable but nil value") } res.NoPeersAvailable = v case *InvalidServiceName: if v == nil { return false, nil, fmt.Errorf("Handler for invalidServiceName returned non-nil error type *InvalidServiceName but nil value") } res.InvalidServiceName = v default: return false, nil, err } } else { res.Success = r } return err == nil, &res, nil } ================================================ FILE: hyperbahn/gen-go/hyperbahn/ttypes.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package hyperbahn import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) 
// NoPeersAvailable carries a message and a service name; both fields are
// tagged as required.
//
// Attributes:
//  - Message
//  - ServiceName
type NoPeersAvailable struct {
	Message     string `thrift:"message,1,required" db:"message" json:"message"`
	ServiceName string `thrift:"serviceName,2,required" db:"serviceName" json:"serviceName"`
}

// NewNoPeersAvailable returns a pointer to a zero-valued NoPeersAvailable.
func NewNoPeersAvailable() *NoPeersAvailable {
	return new(NoPeersAvailable)
}

// GetMessage returns the Message field.
func (p *NoPeersAvailable) GetMessage() string {
	return p.Message
}

// GetServiceName returns the ServiceName field.
func (p *NoPeersAvailable) GetServiceName() string {
	return p.ServiceName
}
thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.ServiceName = v } return nil } func (p *NoPeersAvailable) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("NoPeersAvailable"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *NoPeersAvailable) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("message", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:message: ", p), err) } if err := oprot.WriteString(string(p.Message)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.message (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:message: ", p), err) } return err } func (p *NoPeersAvailable) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("serviceName", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:serviceName: ", p), err) } if err := oprot.WriteString(string(p.ServiceName)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.serviceName (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:serviceName: ", p), err) } return err } func (p *NoPeersAvailable) String() string { if p == nil { return "" } return fmt.Sprintf("NoPeersAvailable(%+v)", *p) } func (p *NoPeersAvailable) Error() 
string { return p.String() } // Attributes: // - Message // - ServiceName type InvalidServiceName struct { Message string `thrift:"message,1,required" db:"message" json:"message"` ServiceName string `thrift:"serviceName,2,required" db:"serviceName" json:"serviceName"` } func NewInvalidServiceName() *InvalidServiceName { return &InvalidServiceName{} } func (p *InvalidServiceName) GetMessage() string { return p.Message } func (p *InvalidServiceName) GetServiceName() string { return p.ServiceName } func (p *InvalidServiceName) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetMessage bool = false var issetServiceName bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetMessage = true case 2: if err := p.ReadField2(iprot); err != nil { return err } issetServiceName = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetMessage { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Message is not set")) } if !issetServiceName { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field ServiceName is not set")) } return nil } func (p *InvalidServiceName) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Message = v } return nil } func (p *InvalidServiceName) ReadField2(iprot thrift.TProtocol) error { if v, err := 
iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.ServiceName = v } return nil } func (p *InvalidServiceName) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("InvalidServiceName"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *InvalidServiceName) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("message", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:message: ", p), err) } if err := oprot.WriteString(string(p.Message)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.message (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:message: ", p), err) } return err } func (p *InvalidServiceName) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("serviceName", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:serviceName: ", p), err) } if err := oprot.WriteString(string(p.ServiceName)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.serviceName (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:serviceName: ", p), err) } return err } func (p *InvalidServiceName) String() string { if p == nil { return "" } return fmt.Sprintf("InvalidServiceName(%+v)", *p) } func (p *InvalidServiceName) Error() string { return p.String() } 
// Attributes: // - ServiceName type DiscoveryQuery struct { ServiceName string `thrift:"serviceName,1,required" db:"serviceName" json:"serviceName"` } func NewDiscoveryQuery() *DiscoveryQuery { return &DiscoveryQuery{} } func (p *DiscoveryQuery) GetServiceName() string { return p.ServiceName } func (p *DiscoveryQuery) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetServiceName bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetServiceName = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetServiceName { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field ServiceName is not set")) } return nil } func (p *DiscoveryQuery) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.ServiceName = v } return nil } func (p *DiscoveryQuery) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("DiscoveryQuery"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *DiscoveryQuery) writeField1(oprot 
thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("serviceName", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:serviceName: ", p), err) } if err := oprot.WriteString(string(p.ServiceName)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.serviceName (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:serviceName: ", p), err) } return err } func (p *DiscoveryQuery) String() string { if p == nil { return "" } return fmt.Sprintf("DiscoveryQuery(%+v)", *p) } // Attributes: // - Ipv4 type IpAddress struct { Ipv4 *int32 `thrift:"ipv4,1" db:"ipv4" json:"ipv4,omitempty"` } func NewIpAddress() *IpAddress { return &IpAddress{} } var IpAddress_Ipv4_DEFAULT int32 func (p *IpAddress) GetIpv4() int32 { if !p.IsSetIpv4() { return IpAddress_Ipv4_DEFAULT } return *p.Ipv4 } func (p *IpAddress) CountSetFieldsIpAddress() int { count := 0 if p.IsSetIpv4() { count++ } return count } func (p *IpAddress) IsSetIpv4() bool { return p.Ipv4 != nil } func (p *IpAddress) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *IpAddress) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadI32(); err != nil { return thrift.PrependError("error reading field 
1: ", err) } else { p.Ipv4 = &v } return nil } func (p *IpAddress) Write(oprot thrift.TProtocol) error { if c := p.CountSetFieldsIpAddress(); c != 1 { return fmt.Errorf("%T write union: exactly one field must be set (%d set).", p, c) } if err := oprot.WriteStructBegin("IpAddress"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *IpAddress) writeField1(oprot thrift.TProtocol) (err error) { if p.IsSetIpv4() { if err := oprot.WriteFieldBegin("ipv4", thrift.I32, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:ipv4: ", p), err) } if err := oprot.WriteI32(int32(*p.Ipv4)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.ipv4 (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:ipv4: ", p), err) } } return err } func (p *IpAddress) String() string { if p == nil { return "" } return fmt.Sprintf("IpAddress(%+v)", *p) } // Attributes: // - IP // - Port type ServicePeer struct { IP *IpAddress `thrift:"ip,1,required" db:"ip" json:"ip"` Port int32 `thrift:"port,2,required" db:"port" json:"port"` } func NewServicePeer() *ServicePeer { return &ServicePeer{} } var ServicePeer_IP_DEFAULT *IpAddress func (p *ServicePeer) GetIP() *IpAddress { if !p.IsSetIP() { return ServicePeer_IP_DEFAULT } return p.IP } func (p *ServicePeer) GetPort() int32 { return p.Port } func (p *ServicePeer) IsSetIP() bool { return p.IP != nil } func (p *ServicePeer) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var 
issetIP bool = false var issetPort bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetIP = true case 2: if err := p.ReadField2(iprot); err != nil { return err } issetPort = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetIP { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field IP is not set")) } if !issetPort { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Port is not set")) } return nil } func (p *ServicePeer) ReadField1(iprot thrift.TProtocol) error { p.IP = &IpAddress{} if err := p.IP.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.IP), err) } return nil } func (p *ServicePeer) ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadI32(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.Port = v } return nil } func (p *ServicePeer) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("ServicePeer"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *ServicePeer) writeField1(oprot thrift.TProtocol) (err 
error) { if err := oprot.WriteFieldBegin("ip", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:ip: ", p), err) } if err := p.IP.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.IP), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:ip: ", p), err) } return err } func (p *ServicePeer) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("port", thrift.I32, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:port: ", p), err) } if err := oprot.WriteI32(int32(p.Port)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.port (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:port: ", p), err) } return err } func (p *ServicePeer) String() string { if p == nil { return "" } return fmt.Sprintf("ServicePeer(%+v)", *p) } // Attributes: // - Peers type DiscoveryResult_ struct { Peers []*ServicePeer `thrift:"peers,1,required" db:"peers" json:"peers"` } func NewDiscoveryResult_() *DiscoveryResult_ { return &DiscoveryResult_{} } func (p *DiscoveryResult_) GetPeers() []*ServicePeer { return p.Peers } func (p *DiscoveryResult_) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetPeers bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetPeers = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := 
iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetPeers { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Peers is not set")) } return nil } func (p *DiscoveryResult_) ReadField1(iprot thrift.TProtocol) error { _, size, err := iprot.ReadListBegin() if err != nil { return thrift.PrependError("error reading list begin: ", err) } tSlice := make([]*ServicePeer, 0, size) p.Peers = tSlice for i := 0; i < size; i++ { _elem0 := &ServicePeer{} if err := _elem0.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem0), err) } p.Peers = append(p.Peers, _elem0) } if err := iprot.ReadListEnd(); err != nil { return thrift.PrependError("error reading list end: ", err) } return nil } func (p *DiscoveryResult_) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("DiscoveryResult"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *DiscoveryResult_) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("peers", thrift.LIST, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:peers: ", p), err) } if err := oprot.WriteListBegin(thrift.STRUCT, len(p.Peers)); err != nil { return thrift.PrependError("error writing list begin: ", err) } for _, v := range p.Peers { if err := v.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", v), err) } } if err := oprot.WriteListEnd(); err != nil { return thrift.PrependError("error writing list end: ", err) } if err := 
// DiscoveryResult is the return type of Hyperbahn.discover. It carries the
// list of peers found for the queried service; the generated Go binding
// exposes the field as `Peers []*ServicePeer`, so the element type of the
// list is ServicePeer.
struct DiscoveryResult {
  1: required list<ServicePeer> peers
}
// intToIP4 converts an IPv4 address packed into a uint32 (most significant
// byte first) into a 4-byte net.IP.
func intToIP4(ip uint32) net.IP {
	octets := make(net.IP, net.IPv4len)
	for i := range octets {
		// Octet 0 comes from the top 8 bits, octet 3 from the bottom 8; the
		// byte conversion keeps only the low 8 bits of the shifted value.
		shift := uint(8 * (net.IPv4len - 1 - i))
		octets[i] = byte(ip >> shift)
	}
	return octets
}
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package hyperbahn import ( "testing" "github.com/stretchr/testify/assert" ) func TestIntToIP4(t *testing.T) { tests := []struct { ip uint32 expected string }{ { ip: 0, expected: "0.0.0.0", }, { ip: 0x01010101, expected: "1.1.1.1", }, { ip: 0x01030507, expected: "1.3.5.7", }, { ip: 0xFFFFFFFF, expected: "255.255.255.255", }, } for _, tt := range tests { got := intToIP4(tt.ip).String() assert.Equal(t, tt.expected, got, "IP %v not converted correctly", tt.ip) } } ================================================ FILE: idle_sweep.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel import "time" // idleSweep controls a periodic task that looks for idle connections and clears // them from the peer list. // NOTE: This struct is not thread-safe on its own. Calls to Start() and Stop() // should be guarded by locking ch.mutable type idleSweep struct { ch *Channel maxIdleTime time.Duration idleCheckInterval time.Duration stopCh chan struct{} started bool } // startIdleSweep starts a poller that checks for idle connections at given // intervals. func startIdleSweep(ch *Channel, opts *ChannelOptions) *idleSweep { is := &idleSweep{ ch: ch, maxIdleTime: opts.MaxIdleTime, idleCheckInterval: opts.IdleCheckInterval, } is.start() return is } // Start runs the goroutine responsible for checking idle connections. func (is *idleSweep) start() { if is.started || is.idleCheckInterval <= 0 { return } is.ch.log.WithFields( LogField{"idleCheckInterval", is.idleCheckInterval}, LogField{"maxIdleTime", is.maxIdleTime}, ).Info("Starting idle connections poller.") is.started = true is.stopCh = make(chan struct{}) go is.pollerLoop() } // Stop kills the poller checking for idle connections. func (is *idleSweep) Stop() { if !is.started { return } is.started = false is.ch.log.Info("Stopping idle connections poller.") close(is.stopCh) } func (is *idleSweep) pollerLoop() { ticker := is.ch.timeTicker(is.idleCheckInterval) for { select { case <-ticker.C: is.checkIdleConnections() case <-is.stopCh: ticker.Stop() return } } } func (is *idleSweep) checkIdleConnections() { now := is.ch.timeNow() // Acquire the read lock and examine which connections are idle. 
idleConnections := make([]*Connection, 0, 10) is.ch.mutable.RLock() for _, conn := range is.ch.mutable.conns { lastActivityTime := conn.getLastActivityReadTime() if sendActivityTime := conn.getLastActivityWriteTime(); lastActivityTime.Before(sendActivityTime) { lastActivityTime = sendActivityTime } if idleTime := now.Sub(lastActivityTime); idleTime >= is.maxIdleTime { idleConnections = append(idleConnections, conn) } } is.ch.mutable.RUnlock() for _, conn := range idleConnections { // It's possible that the connection is already closed when we get here. if !conn.IsActive() { continue } // We shouldn't get to a state where we have pending calls, but the connection // is idle. This either means the max-idle time is too low, or there's a stuck call. if conn.hasPendingCalls() { conn.log.Error("Skip closing idle Connection as it has pending calls.") continue } conn.close( LogField{"reason", "Idle connection closed"}, LogField{"lastActivityTimeRead", conn.getLastActivityReadTime()}, LogField{"lastActivityTimeWrite", conn.getLastActivityWriteTime()}, ) } } ================================================ FILE: idle_sweep_test.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "fmt" "strings" "testing" "time" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // peerStatusListener is a test tool used to wait for connections to drop by // listening to status events from a channel. type peerStatusListener struct { changes chan struct{} } func newPeerStatusListener() *peerStatusListener { return &peerStatusListener{ changes: make(chan struct{}, 10), } } func (pl *peerStatusListener) onStatusChange(p *Peer) { pl.changes <- struct{}{} } func (pl *peerStatusListener) waitForZeroConnections(t testing.TB, channels ...*Channel) bool { for { select { case <-pl.changes: if allConnectionsClosed(channels) { return true } case <-time.After(testutils.Timeout(time.Second)): t.Fatalf("Some connections are still open: %s", connectionStatus(channels)) return false } } } func (pl *peerStatusListener) waitForZeroExchanges(t testing.TB, channels ...*Channel) { var ( isEmpty bool status []string ) if !testutils.WaitFor(100*time.Millisecond, func() bool { isEmpty, status = allExchangesEmpty(channels) return isEmpty }) { t.Fatalf("Some exchanges are still non-empty: %s", strings.Join(status, ", ")) } } func allConnectionsClosed(channels []*Channel) bool { for _, ch := range channels { if numConnections(ch) != 0 { return false } } return true } func allExchangesEmpty(channels []*Channel) (isEmpty bool, status []string) { for _, 
ch := range channels { n := numExchanges(ch) if n == 0 { return true, nil } status = append(status, fmt.Sprintf("%s: %d open", ch.PeerInfo().ProcessName, n)) } return false, status } func numConnections(ch *Channel) int { rootPeers := ch.RootPeers().Copy() count := 0 for _, peer := range rootPeers { in, out := peer.NumConnections() count += in + out } return count } func numExchanges(ch *Channel) int { var num int rootPeers := ch.RootPeers().Copy() for _, p := range rootPeers { state := p.IntrospectState(nil) for _, c := range state.InboundConnections { num += c.InboundExchange.Count + c.OutboundExchange.Count } for _, c := range state.OutboundConnections { num += c.InboundExchange.Count + c.OutboundExchange.Count } } return num } func connectionStatus(channels []*Channel) string { status := make([]string, 0) for _, ch := range channels { status = append(status, fmt.Sprintf("%s: %d open", ch.PeerInfo().ProcessName, numConnections(ch))) } return strings.Join(status, ", ") } // Validates that inbound idle connections are dropped. func TestServerBasedSweep(t *testing.T) { listener := newPeerStatusListener() ctx, cancel := NewContext(time.Second) defer cancel() serverTicker := testutils.NewFakeTicker() clock := testutils.NewStubClock(time.Now()) serverOpts := testutils.NewOpts(). SetTimeTicker(serverTicker.New). SetIdleCheckInterval(30 * time.Second). SetMaxIdleTime(3 * time.Minute). SetOnPeerStatusChanged(listener.onStatusChange). SetTimeNow(clock.Now). NoRelay() clientOpts := testutils.NewOpts(). SetOnPeerStatusChanged(listener.onStatusChange) testutils.WithTestServer(t, serverOpts, func(t testing.TB, ts *testutils.TestServer) { testutils.RegisterEcho(ts.Server(), nil) client := ts.NewClient(clientOpts) raw.Call(ctx, client, ts.HostPort(), ts.ServiceName(), "echo", nil, nil) // Both server and client now have an active connection. After 3 minutes they // should be cleared out by the idle sweep. 
// TestClientBasedSweep validates that outbound idle connections are dropped.
//
// Only the client channel is configured with idle-sweep settings (3 minute
// max idle time, 30 second check interval), a stub clock and a fake ticker;
// the server gets none of these. The test then checks that both the client
// and the server end up with zero connections once the client's sweep runs.
func TestClientBasedSweep(t *testing.T) {
	listener := newPeerStatusListener()
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	// The fake ticker and stub clock let the test decide when the client's
	// poller fires and what time it observes.
	clientTicker := testutils.NewFakeTicker()
	clock := testutils.NewStubClock(time.Now())

	serverOpts := testutils.NewOpts().
		SetOnPeerStatusChanged(listener.onStatusChange).
		NoRelay()
	clientOpts := testutils.NewOpts().
		SetTimeNow(clock.Now).
		SetTimeTicker(clientTicker.New).
		SetMaxIdleTime(3 * time.Minute).
		SetOnPeerStatusChanged(listener.onStatusChange).
		SetIdleCheckInterval(30 * time.Second)

	testutils.WithTestServer(t, serverOpts, func(t testing.TB, ts *testutils.TestServer) {
		testutils.RegisterEcho(ts.Server(), nil)
		client := ts.NewClient(clientOpts)
		// NOTE(review): the results of raw.Call, including any error, are
		// discarded here; a failed call would only surface through the
		// connection-count assertions below.
		raw.Call(ctx, client, ts.HostPort(), ts.ServiceName(), "echo", nil, nil)

		// Both server and client now have an active connection. After 3 minutes they
		// should be cleared out by the idle sweep.
		// This first tick happens before the stub clock has been advanced, so
		// the connection is expected to remain on both sides.
		clientTicker.Tick()
		assert.Equal(t, 1, numConnections(ts.Server()))
		assert.Equal(t, 1, numConnections(client))

		// Move the clock forward and trigger the idle poller.
		clock.Elapse(180 * time.Second)
		clientTicker.Tick()
		listener.waitForZeroConnections(t, ts.Server(), client)
	})
}
SetMaxIdleTime(3 * time.Minute). SetIdleCheckInterval(30 * time.Second). SetOnPeerStatusChanged(listener.onStatusChange). SetDisableServer(). // We create our own server without the idle sweeper. SetRelayOnly() testutils.WithTestServer(t, relayOpts, func(t testing.TB, ts *testutils.TestServer) { server := ts.NewServer(opts) testutils.RegisterEcho(server, nil) // Make a call to the server via relay, which will establish connections: // Client -> Relay -> Server client := ts.NewClient(opts) testutils.AssertEcho(t, client, ts.HostPort(), server.ServiceName()) relayTicker.Tick() // Relay has 1 inbound + 1 outbound assert.Equal(t, 2, numConnections(ts.Relay())) assert.Equal(t, 1, numConnections(server)) assert.Equal(t, 1, numConnections(client)) // The relay will drop both sides of the connection after 3 minutes of inactivity. clock.Elapse(180 * time.Second) listener.waitForZeroExchanges(t, client) relayTicker.Tick() listener.waitForZeroConnections(t, ts.Relay(), server, client) }) } // Validates that pings do not keep the connection alive. func TestIdleSweepWithPings(t *testing.T) { listener := newPeerStatusListener() ctx, cancel := NewContext(time.Second) defer cancel() clientTicker := testutils.NewFakeTicker() clock := testutils.NewStubClock(time.Now()) serverOpts := testutils.NewOpts(). SetOnPeerStatusChanged(listener.onStatusChange). NoRelay() clientOpts := testutils.NewOpts(). SetTimeNow(clock.Now). SetTimeTicker(clientTicker.New). SetMaxIdleTime(3 * time.Minute). SetIdleCheckInterval(30 * time.Second). SetOnPeerStatusChanged(listener.onStatusChange) testutils.WithTestServer(t, serverOpts, func(t testing.TB, ts *testutils.TestServer) { testutils.RegisterEcho(ts.Server(), nil) client := ts.NewClient(clientOpts) raw.Call(ctx, client, ts.HostPort(), ts.ServiceName(), "echo", nil, nil) // Generate pings every minute. 
for i := 0; i < 2; i++ { clock.Elapse(60 * time.Second) client.Ping(ctx, ts.HostPort()) clientTicker.Tick() assert.Equal(t, 1, numConnections(ts.Server())) assert.Equal(t, 1, numConnections(client)) } clock.Elapse(60 * time.Second) clientTicker.Tick() // Connections should still drop, regardless of the ping. listener.waitForZeroConnections(t, ts.Server(), client) }) } // Validates that when MaxIdleTime isn't set, NewChannel returns an error. func TestIdleSweepMisconfiguration(t *testing.T) { ch, err := NewChannel("svc", &ChannelOptions{ IdleCheckInterval: time.Duration(30 * time.Second), }) assert.Nil(t, ch, "NewChannel should not return a channel") assert.Error(t, err, "NewChannel should fail") } func TestIdleSweepIgnoresConnectionsWithCalls(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() clientTicker := testutils.NewFakeTicker() clock := testutils.NewStubClock(time.Now()) listener := newPeerStatusListener() // TODO: Log filtering doesn't require the message to be seen. serverOpts := testutils.NewOpts(). AddLogFilter("Skip closing idle Connection as it has pending calls.", 2). SetOnPeerStatusChanged(listener.onStatusChange). SetTimeNow(clock.Now). SetTimeTicker(clientTicker.New). SetRelayMaxTimeout(time.Hour). SetMaxIdleTime(3 * time.Minute). SetIdleCheckInterval(30 * time.Second) testutils.WithTestServer(t, serverOpts, func(t testing.TB, ts *testutils.TestServer) { var ( gotCall = make(chan struct{}) block = make(chan struct{}) ) testutils.RegisterEcho(ts.Server(), func() { close(gotCall) <-block }) clientOpts := testutils.NewOpts().SetOnPeerStatusChanged(listener.onStatusChange) // Client 1 will just ping, so we create a connection that should be closed. c1 := ts.NewClient(clientOpts) require.True(t, testutils.WaitFor(10*time.Second, func() bool { return c1.Ping(ctx, ts.HostPort()) == nil }), "Ping failed") // Client 2 will make a call that will be blocked. Wait for the call to be received. 
c2CallComplete := make(chan struct{}) c2 := ts.NewClient(clientOpts) go func() { testutils.AssertEcho(t, c2, ts.HostPort(), ts.ServiceName()) close(c2CallComplete) }() <-gotCall // If we are in no-relay mode, we expect 2 connections to the server (from each client). // If we are in relay mode, the relay will have the 2 connections from clients + 1 connection to the server. check := struct { ch *Channel preCloseConns int tick func() }{ ch: ts.Server(), preCloseConns: 2, tick: func() { clock.Elapse(5 * time.Minute) clientTicker.Tick() }, } if ts.HasRelay() { relay := ts.Relay() check.ch = relay check.preCloseConns++ oldTick := check.tick check.tick = func() { oldTick() // The same ticker is being used by the server and the relay // so we need to tick it twice. clientTicker.Tick() } } assert.Equal(t, check.preCloseConns, check.ch.IntrospectNumConnections(), "Expect connection to client 1 and client 2") // Let the idle checker close client 1's connection. listener.waitForZeroExchanges(t, c1) check.tick() listener.waitForZeroConnections(t, c1) // Make sure we have only a connection for client 2, which is active. assert.Equal(t, check.preCloseConns-1, check.ch.IntrospectNumConnections(), "Expect connection only to client 2") state := check.ch.IntrospectState(nil) require.Empty(t, state.InactiveConnections, "Ensure all connections are active") // Unblock the call. close(block) <-c2CallComplete // Since the idle sweep loop and message exchange run concurrently, there is // a race between the idle sweep and exchange shutdown. To mitigate this, // wait for the exchanges to shut down before triggering the idle sweep. listener.waitForZeroExchanges(t, ts.Server(), c2) check.tick() listener.waitForZeroConnections(t, ts.Server(), c2) }) } ================================================ FILE: inbound.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"errors"
	"fmt"
	"net"
	"time"

	"github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/ext"
	"golang.org/x/net/context"
)

// Values recorded in the "rpc.tchannel.error_type" span tag.
const (
	systemErrorType = "system"
	appErrorType    = "application"
)

// errInboundRequestAlreadyActive is reported to the peer when a call req
// arrives with a message ID that already has an inbound exchange.
var errInboundRequestAlreadyActive = errors.New("inbound request is already active; possible duplicate client id")

// handleCallReq handles an incoming call request, registering a message
// exchange to receive further fragments for that call, and dispatching it in
// another goroutine.
//
// It returns true when the caller may release the frame (the call was rejected
// or could not be set up) and false when the frame has been handed over to the
// dispatched call.
func (c *Connection) handleCallReq(frame *Frame) bool {
	now := c.timeNow()
	switch state := c.readState(); state {
	case connectionActive:
		// Only active connections accept new calls; fall through to setup.
	case connectionStartClose, connectionInboundClosed, connectionClosed:
		c.SendSystemError(frame.Header.ID, callReqSpan(frame), ErrChannelClosed)
		return true
	default:
		panic(fmt.Errorf("unknown connection state for call req: %v", state))
	}

	callReq := new(callReq)
	callReq.id = frame.Header.ID
	initialFragment, err := parseInboundFragment(c.opts.FramePool, frame, callReq)
	if err != nil {
		// TODO(mmihic): Probably want to treat this as a protocol error
		c.log.WithFields(
			LogField{"header", frame.Header},
			ErrField(err),
		).Error("Couldn't decode initial fragment.")
		return true
	}

	call := new(InboundCall)
	call.conn = c
	ctx, cancel := newIncomingContext(c.baseContext, call, callReq.TimeToLive)

	mex, err := c.inbound.newExchange(ctx, cancel, c.opts.FramePool, callReq.messageType(), frame.Header.ID, mexChannelBufferSize)
	if err != nil {
		// No exchange took ownership of the context, so release it here
		// rather than waiting for the call's time-to-live to elapse.
		cancel()

		if err == errDuplicateMex {
			err = errInboundRequestAlreadyActive
		}
		c.log.WithFields(
			LogField{"header", frame.Header},
			ErrField(err),
		).Error("Couldn't register exchange.")
		// Report the actual failure: only a duplicate message ID is an
		// "already active" error.
		c.protocolError(frame.Header.ID, err)
		return true
	}

	// Close may have been called between the time we checked the state and us creating the exchange.
	if c.readState() != connectionActive {
		mex.shutdown()
		// The call is never dispatched, so nothing else will cancel the
		// context; cancelling is idempotent if the exchange already did so.
		cancel()
		return true
	}

	response := new(InboundCallResponse)
	response.call = call
	response.calledAt = now
	response.timeNow = c.timeNow
	response.span = c.extractInboundSpan(callReq)
	if response.span != nil {
		mex.ctx = opentracing.ContextWithSpan(mex.ctx, response.span)
	}
	response.mex = mex
	response.conn = c
	response.cancel = cancel
	response.log = c.log.WithFields(LogField{"In-Response", callReq.ID()})
	response.contents = newFragmentingWriter(response.log, response, initialFragment.checksumType.New())
	response.headers = transportHeaders{}
	response.messageForFragment = func(initial bool) message {
		if initial {
			callRes := new(callRes)
			callRes.Headers = response.headers
			callRes.ResponseCode = responseOK
			if response.applicationError {
				callRes.ResponseCode = responseApplicationError
			}
			return callRes
		}

		return new(callResContinue)
	}

	call.mex = mex
	call.initialFragment = initialFragment
	call.serviceName = string(callReq.Service)
	call.headers = callReq.Headers
	call.response = response
	call.log = c.log.WithFields(LogField{"In-Call", callReq.ID()})
	call.messageForFragment = func(initial bool) message { return new(callReqContinue) }
	call.contents = newFragmentingReader(call.log, call)
	call.statsReporter = c.statsReporter
	call.createStatsTags(c.commonStatsTags)

	// The response shares the call's tag map, so tags added to the call later
	// (e.g. the endpoint) are visible to the response's metrics as well.
	response.statsReporter = c.statsReporter
	response.commonStatsTags = call.commonStatsTags

	setResponseHeaders(call.headers, response.headers)
	go c.dispatchInbound(c.connID, callReq.ID(), call, frame)
	return false
}

// handleCallReqContinue handles the continuation of a call request, forwarding
// it to the request channel for that request, where it can be pulled during
// defragmentation. It returns true when the caller may release the frame.
func (c *Connection) handleCallReqContinue(frame *Frame) bool {
	if err := c.inbound.forwardPeerFrame(frame); err != nil {
		// If forward fails, it's due to a timeout. We can free this frame.
		return true
	}
	return false
}

// handleCancel handles a cancel frame from the peer. The cancel is counted,
// and forwarded to the inbound exchange set only when PropagateCancel is
// enabled. It always returns true, as the frame is never retained.
func (c *Connection) handleCancel(frame *Frame) bool {
	c.statsReporter.IncCounter("inbound.cancels.requested", c.commonStatsTags, 1)

	if !c.opts.PropagateCancel {
		if c.log.Enabled(LogLevelDebug) {
			c.log.Debugf("Ignoring cancel for %v", frame.Header.ID)
		}
		return true
	}

	c.statsReporter.IncCounter("inbound.cancels.honored", c.commonStatsTags, 1)
	c.inbound.handleCancel(frame)

	// Free the frame, as it's consumed immediately.
	return true
}

// createStatsTags builds the call's common stats tags: the calling service
// plus a copy of the connection-level tags. A fresh map is created on every
// invocation; connection tags win if they use the same key.
func (call *InboundCall) createStatsTags(connectionTags map[string]string) {
	// Reserve room for "calling-service" and the "endpoint" tag that
	// dispatchInbound adds once the method has been read.
	tags := make(map[string]string, len(connectionTags)+2)
	tags["calling-service"] = call.CallerName()
	for k, v := range connectionTags {
		tags[k] = v
	}
	call.commonStatsTags = tags
}

// dispatchInbound dispatches an inbound call to the appropriate handler.
func (c *Connection) dispatchInbound(_ uint32, _ uint32, call *InboundCall, frame *Frame) {
	if call.log.Enabled(LogLevelDebug) {
		call.log.Debugf("Received incoming call for %s from %s", call.ServiceName(), c.remotePeerInfo)
	}

	if err := call.readMethod(); err != nil {
		call.log.WithFields(
			LogField{"remotePeer", c.remotePeerInfo},
			ErrField(err),
		).Error("Couldn't read method.")
		c.opts.FramePool.Release(frame)
		return
	}

	call.commonStatsTags["endpoint"] = call.methodString
	call.statsReporter.IncCounter("inbound.calls.recvd", call.commonStatsTags, 1)
	if span := call.response.span; span != nil {
		span.SetOperationName(call.methodString)
	}

	// TODO(prashant): This is an expensive way to check for cancellation. Use a heap for timeouts.
	go func() {
		select {
		case <-call.mex.ctx.Done():
			// checking if message exchange timedout or was cancelled
			// only two possible errors at this step:
			// context.DeadlineExceeded
			// context.Canceled
			if call.mex.ctx.Err() != nil {
				call.mex.inboundExpired()
			}
		case <-call.mex.errCh.c:
			if c.log.Enabled(LogLevelDebug) {
				call.log.Debugf("Wait for timeout/cancellation interrupted by error: %v", call.mex.errCh.err)
			}
			// when an exchange errors out, mark the exchange as expired
			// and call cancel so the server handler's context is canceled
			// TODO: move the cancel to the parent context at connection level
			call.response.cancel()
			call.mex.inboundExpired()
		}
	}()

	// Internal handlers (e.g., introspection) trump all other user-registered handlers on
	// the "tchannel" name.
	if call.ServiceName() == "tchannel" {
		if h := c.internalHandlers.find(call.Method()); h != nil {
			h.Handle(call.mex.ctx, call)
			return
		}
	}

	c.handler.Handle(call.mex.ctx, call)
}

// An InboundCall is an incoming call from a peer.
type InboundCall struct {
	reqResReader

	conn            *Connection
	response        *InboundCallResponse
	serviceName     string
	method          []byte
	methodString    string
	headers         transportHeaders
	statsReporter   StatsReporter
	commonStatsTags map[string]string
}

// ServiceName returns the name of the service being called.
func (call *InboundCall) ServiceName() string {
	return call.serviceName
}

// Method returns the method being called.
func (call *InboundCall) Method() []byte {
	return call.method
}

// MethodString returns the method being called as a string.
func (call *InboundCall) MethodString() string {
	return call.methodString
}

// Format returns the format of the request from the ArgScheme transport header.
func (call *InboundCall) Format() Format {
	return Format(call.headers[ArgScheme])
}

// CallerName returns the caller name from the CallerName transport header.
func (call *InboundCall) CallerName() string {
	return call.headers[CallerName]
}

// ShardKey returns the shard key from the ShardKey transport header.
func (call *InboundCall) ShardKey() string {
	shardKey := call.headers[ShardKey]
	return shardKey
}

// RoutingKey reports the value of the RoutingKey transport header.
func (call *InboundCall) RoutingKey() string {
	routingKey := call.headers[RoutingKey]
	return routingKey
}

// RoutingDelegate reports the value of the RoutingDelegate transport header.
func (call *InboundCall) RoutingDelegate() string {
	delegate := call.headers[RoutingDelegate]
	return delegate
}

// LocalPeer reports the local peer information of the connection this call
// arrived on.
func (call *InboundCall) LocalPeer() LocalPeerInfo {
	conn := call.conn
	return conn.localPeerInfo
}

// RemotePeer reports the remote peer information of the connection this call
// arrived on.
func (call *InboundCall) RemotePeer() PeerInfo {
	conn := call.conn
	return conn.RemotePeerInfo()
}

// Connection exposes the raw net.Conn underneath the call's connection.
func (call *InboundCall) Connection() net.Conn {
	conn := call.conn
	return conn.conn
}

// CallOptions builds a CallOptions populated from this call's transport
// headers, suitable for forwarding the request.
func (call *InboundCall) CallOptions() *CallOptions {
	opts := new(CallOptions)
	opts.CallerName = call.CallerName()
	opts.Format = call.Format()
	opts.ShardKey = call.ShardKey()
	opts.RoutingDelegate = call.RoutingDelegate()
	opts.RoutingKey = call.RoutingKey()
	return opts
}

// readMethod consumes the whole of arg1 (the method name) from the request
// stream and stores it on the call, both as bytes and as a string. A read
// failure is passed to call.failed and the result of that is returned.
func (call *InboundCall) readMethod() error {
	var method []byte
	reader := NewArgReader(call.arg1Reader())
	if err := reader.Read(&method); err != nil {
		return call.failed(err)
	}

	call.method, call.methodString = method, string(method)
	return nil
}

// Arg2Reader hands back an ArgReader for the second argument.
// The returned reader must be closed once the argument has been read.
func (call *InboundCall) Arg2Reader() (ArgReader, error) {
	reader, err := call.arg2Reader()
	return reader, err
}

// Arg3Reader returns an ArgReader to read the last argument.
// The ReadCloser must be closed once the argument has been read.
func (call *InboundCall) Arg3Reader() (ArgReader, error) { return call.arg3Reader() } // Response provides access to the InboundCallResponse object which can be used // to write back to the calling peer func (call *InboundCall) Response() *InboundCallResponse { if call.err != nil { // While reading Thrift, we cannot distinguish between malformed Thrift and other errors, // and so we may try to respond with a bad request. We should ensure that the response // is marked as failed if the request has failed so that we don't try to shutdown the exchange // a second time. call.response.err = call.err } return call.response } func (call *InboundCall) doneReading(unexpected error) {} // An InboundCallResponse is used to send the response back to the calling peer type InboundCallResponse struct { reqResWriter call *InboundCall cancel context.CancelFunc // calledAt is the time the inbound call was routed to the application. calledAt time.Time timeNow func() time.Time applicationError bool systemError bool headers transportHeaders span opentracing.Span statsReporter StatsReporter commonStatsTags map[string]string } // SendSystemError returns a system error response to the peer. The call is considered // complete after this method is called, and no further data can be written. func (response *InboundCallResponse) SendSystemError(err error) error { if response.err != nil { return response.err } // Fail all future attempts to read fragments response.state = reqResWriterComplete response.systemError = true response.setSpanErrorDetails(err) response.doneSending() response.call.releasePreviousFragment() span := CurrentSpan(response.mex.ctx) return response.conn.SendSystemError(response.mex.msgID, *span, err) } // SetApplicationError marks the response as being an application error. This method can // only be called before any arguments have been sent to the calling peer. 
func (response *InboundCallResponse) SetApplicationError() error { if response.state > reqResWriterPreArg2 { return response.failed(errReqResWriterStateMismatch{ state: response.state, expectedState: reqResWriterPreArg2, }) } response.applicationError = true response.setSpanErrorDetails(nil) return nil } // Blackhole indicates no response will be sent, and cleans up any resources // associated with this request. This allows for services to trigger a timeout in // clients without holding on to any goroutines on the server. func (response *InboundCallResponse) Blackhole() { response.cancel() } // Arg2Writer returns a WriteCloser that can be used to write the second argument. // The returned writer must be closed once the write is complete. func (response *InboundCallResponse) Arg2Writer() (ArgWriter, error) { if err := NewArgWriter(response.arg1Writer()).Write(nil); err != nil { return nil, err } return response.arg2Writer() } // Arg3Writer returns a WriteCloser that can be used to write the last argument. // The returned writer must be closed once the write is complete. func (response *InboundCallResponse) Arg3Writer() (ArgWriter, error) { return response.arg3Writer() } // setSpanErrorDetails sets the span tags for the error type. func (response *InboundCallResponse) setSpanErrorDetails(err error) { if span := response.span; span != nil { if response.applicationError || response.systemError { errorType := appErrorType if response.systemError { errorType = systemErrorType // if the error is a system error, set the error code as a span tag span.SetTag("rpc.tchannel.system_error_code", GetSystemErrorCode(err).MetricsKey()) } span.SetTag("rpc.tchannel.error_type", errorType) } } } // doneSending shuts down the message exchange for this call. // For incoming calls, the last message is sending the call response. func (response *InboundCallResponse) doneSending() { // TODO(prashant): Move this to when the message is actually being sent. 
now := response.timeNow() if span := response.span; span != nil { if response.applicationError || response.systemError { ext.Error.Set(span, true) } span.FinishWithOptions(opentracing.FinishOptions{FinishTime: now}) } latency := now.Sub(response.calledAt) response.statsReporter.RecordTimer("inbound.calls.latency", response.commonStatsTags, latency) if response.systemError { // TODO(prashant): Report the error code type as per metrics doc and enable. // response.statsReporter.IncCounter("inbound.calls.system-errors", response.commonStatsTags, 1) } else if response.applicationError { response.statsReporter.IncCounter("inbound.calls.app-errors", response.commonStatsTags, 1) } else { response.statsReporter.IncCounter("inbound.calls.success", response.commonStatsTags, 1) } // Cancel the context since the response is complete. response.cancel() // The message exchange is still open if there are no errors, call shutdown. if response.err == nil { response.mex.shutdown() } } ================================================ FILE: inbound_internal_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "fmt" "testing" "time" "github.com/opentracing/opentracing-go/mocktracer" "github.com/stretchr/testify/assert" ) type statsReporter struct{} func (w *statsReporter) IncCounter(name string, tags map[string]string, value int64) { } func (w *statsReporter) UpdateGauge(name string, tags map[string]string, value int64) { } func (w *statsReporter) RecordTimer(name string, tags map[string]string, d time.Duration) { } type testCase struct { name string injectedError error systemError bool applicationError bool expectedSpanError bool expectedSpanErrorType string expectedSystemErrorKey string } func TestTracingSpanError(t *testing.T) { var ( systemError = NewSystemError(ErrCodeBusy, "foo") applicationError = fmt.Errorf("application") ) testCases := []testCase{ { name: "ApplicationError", injectedError: applicationError, systemError: false, applicationError: true, expectedSpanError: true, expectedSpanErrorType: "application", expectedSystemErrorKey: "", }, { name: "SystemError", injectedError: systemError, systemError: true, applicationError: false, expectedSpanError: true, expectedSpanErrorType: "system", expectedSystemErrorKey: GetSystemErrorCode(systemError).MetricsKey(), }, } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { var ( parsedSpan *mocktracer.MockSpan tracer = mocktracer.New() callResp = &InboundCallResponse{ span: tracer.StartSpan("test"), statsReporter: &statsReporter{}, reqResWriter: reqResWriter{ err: tt.injectedError, }, applicationError: tt.applicationError, systemError: tt.systemError, timeNow: time.Now, cancel: func() {}, } ) callResp.setSpanErrorDetails(tt.injectedError) callResp.doneSending() parsedSpan = callResp.span.(*mocktracer.MockSpan) 
assert.Equal(t, tt.expectedSpanError, parsedSpan.Tag("error").(bool)) if tt.expectedSystemErrorKey == "" { assert.Nil(t, parsedSpan.Tag("rpc.tchannel.system_error_code")) } else { assert.Equal(t, tt.expectedSystemErrorKey, parsedSpan.Tag("rpc.tchannel.system_error_code").(string)) } assert.Equal(t, tt.expectedSpanErrorType, parsedSpan.Tag("rpc.tchannel.error_type").(string)) }) } } ================================================ FILE: inbound_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "strings" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/jaeger-client-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) func TestSpanReportingForErrors(t *testing.T) { injectedSystemError := ErrTimeout tests := []struct { name string method string systemErr bool applicationErr bool }{ { name: "System Error", method: "system-error", systemErr: true, applicationErr: false, }, { name: "Application Error", method: "app-error", systemErr: false, applicationErr: true, }, { name: "No Error", method: "no-error", systemErr: false, applicationErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // We use a jaeger tracer here and not Mocktracer: because jaeger supports // zipkin format which is essential for inbound span extraction jaegerReporter := jaeger.NewInMemoryReporter() jaegerTracer, jaegerCloser := jaeger.NewTracer(testutils.DefaultServerName, jaeger.NewConstSampler(true), jaegerReporter) defer jaegerCloser.Close() opts := &testutils.ChannelOpts{ ChannelOptions: ChannelOptions{Tracer: jaegerTracer}, } testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { // Register handler that returns app error ts.RegisterFunc("app-error", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ IsErr: true, }, nil }) // Register handler that returns system error ts.RegisterFunc("system-error", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{ SystemErr: injectedSystemError, }, nil }) // Register handler that returns no error ts.RegisterFunc("no-error", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { return &raw.Res{}, nil }) ctx, cancel := NewContext(20 * time.Second) defer cancel() clientCh := ts.NewClient(opts) defer clientCh.Close() // Make a new call, which should fail _, _, resp, err := raw.Call(ctx, clientCh, ts.HostPort(), 
ts.ServiceName(), tt.method, []byte("Arg2"), []byte("Arg3")) if tt.systemErr { // Providing 'got: %q' is necessary since SystemErrCode is a type alias of byte; testify's // failed test ouput would otherwise print out hex codes. assert.Equal(t, injectedSystemError, err, "expected cancelled error code, got: %q", err) } else { assert.Nil(t, err, "expected no system error code") } if tt.applicationErr { assert.True(t, resp.ApplicationError(), "Call(%v) check application error") } else if !tt.systemErr { assert.False(t, resp.ApplicationError(), "Call(%v) check application error") } }) // We should have 4 spans, 2 for client and 2 for server assert.Equal(t, len(jaegerReporter.GetSpans()), 4) for _, span := range jaegerReporter.GetSpans() { if span.(*jaeger.Span).Tags()["span.kind"] == "server" { assert.Equal(t, span.(*jaeger.Span).Tags()["error"], true) if tt.applicationErr { assert.Equal(t, span.(*jaeger.Span).Tags()["rpc.tchannel.error_type"], "application") assert.Nil(t, span.(*jaeger.Span).Tags()["rpc.tchannel.system_error_code"]) } else if tt.systemErr { assert.Equal(t, span.(*jaeger.Span).Tags()["rpc.tchannel.error_type"], "system") assert.Equal(t, span.(*jaeger.Span).Tags()["rpc.tchannel.system_error_code"], GetSystemErrorCode(injectedSystemError).MetricsKey()) } else { assert.Nil(t, span.(*jaeger.Span).Tags()["rpc.tchannel.error_type"]) assert.Nil(t, span.(*jaeger.Span).Tags()["rpc.tchannel.system_error_code"]) } } } jaegerReporter.Reset() }) } } func TestActiveCallReq(t *testing.T) { t.Skip("Test skipped due to unreliable way to test for protocol errors") ctx, cancel := NewContext(time.Second) defer cancel() // Note: This test cannot use log verification as the duplicate ID causes a log. // It does not use a verified server, as it leaks a message exchange due to the // modification of IDs in the relay. 
opts := testutils.NewOpts().DisableLogVerification() testutils.WithServer(t, opts, func(ch *Channel, hostPort string) { gotCall := make(chan struct{}) unblock := make(chan struct{}) testutils.RegisterFunc(ch, "blocked", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { gotCall <- struct{}{} <-unblock return &raw.Res{}, nil }) relayFunc := func(outgoing bool, frame *Frame) *Frame { if outgoing && frame.Header.ID == 3 { frame.Header.ID = 2 } return frame } relayHostPort, closeRelay := testutils.FrameRelay(t, hostPort, relayFunc) defer closeRelay() firstComplete := make(chan struct{}) go func() { // This call will block until we close unblock. raw.Call(ctx, ch, relayHostPort, ch.PeerInfo().ServiceName, "blocked", nil, nil) close(firstComplete) }() // Wait for the first call to be received by the server <-gotCall // Make a new call, which should fail _, _, _, err := raw.Call(ctx, ch, relayHostPort, ch.PeerInfo().ServiceName, "blocked", nil, nil) assert.Error(t, err, "Expect error") assert.True(t, strings.Contains(err.Error(), "already active"), "expected already active error, got %v", err) close(unblock) <-firstComplete }) } func TestInboundConnection(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() // Disable relay since relays hide host:port on outbound calls. 
opts := testutils.NewOpts().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { s2 := ts.NewServer(nil) ts.RegisterFunc("test", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { c, rawConn := InboundConnection(CurrentCall(ctx)) assert.Equal(t, s2.PeerInfo().HostPort, c.RemotePeerInfo().HostPort, "Unexpected host port") assert.NotNil(t, rawConn, "unexpected connection") return &raw.Res{}, nil }) _, _, _, err := raw.Call(ctx, s2, ts.HostPort(), ts.ServiceName(), "test", nil, nil) require.NoError(t, err, "Call failed") }) } func TestInboundConnection_CallOptions(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() testutils.WithTestServer(t, nil, func(t testing.TB, server *testutils.TestServer) { server.RegisterFunc("test", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { assert.Equal(t, "client", CurrentCall(ctx).CallerName(), "Expected caller name to be passed through") return &raw.Res{}, nil }) backendName := server.ServiceName() proxyCh := server.NewServer(&testutils.ChannelOpts{ServiceName: "proxy"}) defer proxyCh.Close() subCh := proxyCh.GetSubChannel(backendName) subCh.SetHandler(HandlerFunc(func(ctx context.Context, inbound *InboundCall) { outbound, err := proxyCh.BeginCall(ctx, server.HostPort(), backendName, inbound.MethodString(), inbound.CallOptions()) require.NoError(t, err, "Create outbound call failed") arg2, arg3, _, err := raw.WriteArgs(outbound, []byte("hello"), []byte("world")) require.NoError(t, err, "Write outbound call failed") require.NoError(t, raw.WriteResponse(inbound.Response(), &raw.Res{ Arg2: arg2, Arg3: arg3, }), "Write response failed") })) clientCh := server.NewClient(&testutils.ChannelOpts{ ServiceName: "client", }) defer clientCh.Close() _, _, _, err := raw.Call(ctx, clientCh, proxyCh.PeerInfo().HostPort, backendName, "test", nil, nil) require.NoError(t, err, "Call through proxy failed") }) } func TestCallOptionsPropogated(t *testing.T) { const handler = 
"handler" giveCallOpts := CallOptions{ Format: JSON, CallerName: "test-caller-name", ShardKey: "test-shard-key", RoutingKey: "test-routing-key", RoutingDelegate: "test-routing-delegate", } var gotCallOpts *CallOptions testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(HandlerFunc(func(ctx context.Context, inbound *InboundCall) { gotCallOpts = inbound.CallOptions() err := raw.WriteResponse(inbound.Response(), &raw.Res{}) assert.NoError(t, err, "write response failed") }), handler) ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() call, err := ts.Server().BeginCall(ctx, ts.HostPort(), ts.ServiceName(), handler, &giveCallOpts) require.NoError(t, err, "could not call test server") _, _, _, err = raw.WriteArgs(call, nil, nil) require.NoError(t, err, "could not write args") assert.Equal(t, &giveCallOpts, gotCallOpts) }) } func TestBlackhole(t *testing.T) { ctx, cancel := NewContext(testutils.Timeout(time.Hour)) testutils.WithTestServer(t, nil, func(t testing.TB, server *testutils.TestServer) { serviceName := server.ServiceName() handlerName := "test-handler" server.Register(HandlerFunc(func(ctx context.Context, inbound *InboundCall) { // cancel client context in handler so the client can return after being blackholed defer cancel() c, _ := InboundConnection(inbound) require.NotNil(t, c) state := c.IntrospectState(&IntrospectionOptions{}) require.Equal(t, 1, state.InboundExchange.Count, "expected exactly one inbound exchange") // blackhole request inbound.Response().Blackhole() // give time for exchange to cleanup require.True(t, testutils.WaitFor(10*time.Millisecond, func() bool { state = c.IntrospectState(&IntrospectionOptions{}) return state.InboundExchange.Count == 0 }), "expected no inbound exchanges", ) }), handlerName) clientCh := server.NewClient(nil) defer clientCh.Close() _, _, _, err := raw.Call(ctx, clientCh, server.HostPort(), serviceName, handlerName, nil, nil) require.Error(t, err, "expected 
call error") errCode := GetSystemErrorCode(err) // Providing 'got: %q' is necessary since SystemErrCode is a type alias of byte; testify's // failed test output would otherwise print out hex codes. assert.Equal(t, ErrCodeCancelled, errCode, "expected cancelled error code, got: %q", errCode) }) } ================================================ FILE: incoming_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" ) func TestPeersIncomingConnection(t *testing.T) { newService := func(svcName string) (*Channel, string) { ch, _, hostPort := NewServer(t, &testutils.ChannelOpts{ServiceName: svcName}) return ch, hostPort } opts := testutils.NewOpts().NoRelay() WithVerifiedServer(t, opts, func(ch *Channel, hostPort string) { doPing := func(ch *Channel) { ctx, cancel := NewContext(time.Second) defer cancel() assert.NoError(t, ch.Ping(ctx, hostPort), "Ping failed") } hyperbahnSC := ch.GetSubChannel("hyperbahn") ringpopSC := ch.GetSubChannel("ringpop", Isolated) hyperbahn, hyperbahnHostPort := newService("hyperbahn") defer hyperbahn.Close() ringpop, ringpopHostPort := newService("ringpop") defer ringpop.Close() doPing(hyperbahn) doPing(ringpop) // The root peer list should contain all incoming connections. rootPeers := ch.RootPeers().Copy() assert.NotNil(t, rootPeers[hyperbahnHostPort], "missing hyperbahn peer") assert.NotNil(t, rootPeers[ringpopHostPort], "missing ringpop peer") for _, sc := range []Registrar{ch, hyperbahnSC, ringpopSC} { _, err := sc.Peers().Get(nil) assert.Equal(t, ErrNoPeers, err, "incoming connections should not be added to non-root peer list") } // verify number of peers/connections on the client side serverState := ch.IntrospectState(nil).RootPeers serverHostPort := ch.PeerInfo().HostPort assert.Equal(t, len(serverState), 2, "Incorrect peer count") for _, client := range []*Channel{ringpop, hyperbahn} { clientPeerState := client.IntrospectState(nil).RootPeers clientHostPort := client.PeerInfo().HostPort assert.Equal(t, len(clientPeerState), 1, "Incorrect peer count") assert.Equal(t, len(clientPeerState[serverHostPort].OutboundConnections), 1, "Incorrect outbound connection count") assert.Equal(t, len(clientPeerState[serverHostPort].InboundConnections), 0, "Incorrect inbound connection count") assert.Equal(t, 
len(serverState[clientHostPort].InboundConnections), 1, "Incorrect inbound connection count") assert.Equal(t, len(serverState[clientHostPort].OutboundConnections), 0, "Incorrect outbound connection count") } // In future when connections send a service name, we should be able to // check that a new connection containing a service name for an isolated // subchannel is only added to the isolated subchannels' peers, but all // other incoming connections are added to the shared peer list. }) } ================================================ FILE: init_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"bytes"
	"io"
	"net"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// writeMessage serializes msg into a newly allocated frame and writes that
// frame to w, returning the first error encountered.
func writeMessage(w io.Writer, msg message) error {
	f := NewFrame(MaxFramePayloadSize)
	if err := f.write(msg); err != nil {
		return err
	}
	return f.WriteOut(w)
}

// readFrame reads a single frame from r. The frame is returned together with
// any error from reading it.
func readFrame(r io.Reader) (*Frame, error) {
	f := NewFrame(MaxFramePayloadSize)
	return f, f.ReadIn(r)
}

// TestUnexpectedInitReq sends malformed or unexpected first messages to a
// listening channel and verifies that it answers with a protocol error frame
// carrying the ID of the offending message.
func TestUnexpectedInitReq(t *testing.T) {
	tests := []struct {
		name          string
		initMsg       message
		expectedError errorMessage
	}{
		{
			name: "bad version",
			initMsg: &initReq{initMessage{id: 1, Version: 0x1, initParams: initParams{
				InitParamHostPort:    "0.0.0.0:0",
				InitParamProcessName: "test",
			}}},
			expectedError: errorMessage{
				id:      1,
				errCode: ErrCodeProtocol,
			},
		},
		{
			name: "missing InitParamHostPort",
			initMsg: &initReq{initMessage{id: 2, Version: CurrentProtocolVersion, initParams: initParams{
				InitParamProcessName: "test",
			}}},
			expectedError: errorMessage{
				id:      2,
				errCode: ErrCodeProtocol,
			},
		},
		{
			name: "missing InitParamProcessName",
			initMsg: &initReq{initMessage{id: 3, Version: CurrentProtocolVersion, initParams: initParams{
				InitParamHostPort: "0.0.0.0:0",
			}}},
			expectedError: errorMessage{
				id:      3,
				errCode: ErrCodeProtocol,
			},
		},
		{
			name: "unexpected message type",
			initMsg: &pingReq{
				id: 1,
			},
			expectedError: errorMessage{
				id:      1,
				errCode: ErrCodeProtocol,
			},
		},
	}

	for _, tt := range tests {
		// Each case runs as a subtest so its deferred cleanup fires at the end
		// of the case rather than piling up until the whole test returns.
		t.Run(tt.name, func(t *testing.T) {
			ch, err := NewChannel("test", nil)
			require.NoError(t, err)
			defer ch.Close()
			require.NoError(t, ch.ListenAndServe("127.0.0.1:0"))
			hostPort := ch.PeerInfo().HostPort

			conn, err := net.Dial("tcp", hostPort)
			require.NoError(t, err)
			// Backstop for the early returns below; the happy path closes the
			// connection explicitly and asserts on the result.
			defer conn.Close()
			conn.SetReadDeadline(time.Now().Add(time.Second))

			if !assert.NoError(t, writeMessage(conn, tt.initMsg), "write to conn failed") {
				return
			}

			f, err := readFrame(conn)
			if !assert.NoError(t, err, "read frame failed") {
				return
			}
			assert.Equal(t, messageTypeError, f.Header.messageType)
			var errMsg errorMessage
			if !assert.NoError(t, f.read(&errMsg), "parse frame to errorMessage") {
				return
			}
			assert.Equal(t, tt.expectedError.ID(), f.Header.ID, "test %v got bad ID", tt.name)
			assert.Equal(t, tt.expectedError.errCode, errMsg.errCode, "test %v got bad code", tt.name)
			assert.NoError(t, conn.Close(), "closing connection failed")
		})
	}
}

// TestUnexpectedInitRes answers a channel's initReq with invalid initRes
// messages and verifies that the outbound connection attempt fails with a
// protocol error containing the expected message.
func TestUnexpectedInitRes(t *testing.T) {
	validParams := initParams{
		InitParamHostPort:    "0.0.0.0:0",
		InitParamProcessName: "tchannel-go.test",
	}
	tests := []struct {
		msg    message
		errMsg string
	}{
		{
			msg: &initRes{initMessage{
				id:         1,
				Version:    CurrentProtocolVersion - 1,
				initParams: validParams,
			}},
			errMsg: "unsupported protocol version",
		},
		{
			msg: &initRes{initMessage{
				id:         1,
				Version:    CurrentProtocolVersion + 1,
				initParams: validParams,
			}},
			errMsg: "unsupported protocol version",
		},
		{
			msg: &initRes{initMessage{
				id:      1,
				Version: CurrentProtocolVersion,
			}},
			errMsg: "header host_port is required",
		},
		{
			msg: &initRes{initMessage{
				id:      1,
				Version: CurrentProtocolVersion,
				initParams: initParams{
					InitParamHostPort: "0.0.0.0:0",
				},
			}},
			errMsg: "header process_name is required",
		},
	}

	for _, tt := range tests {
		t.Run(tt.errMsg, func(t *testing.T) {
			ln, err := net.Listen("tcp", "127.0.0.1:0")
			require.NoError(t, err, "net.Listen failed")
			defer ln.Close()

			done := make(chan struct{})
			go func() {
				defer close(done)

				// require's FailNow may only be called from the goroutine
				// running the test, so this goroutine asserts and returns.
				ch, err := NewChannel("test", nil)
				if !assert.NoError(t, err) {
					return
				}
				defer ch.Close()

				ctx, cancel := NewContext(time.Second)
				defer cancel()
				_, err = ch.Peers().GetOrAdd(ln.Addr().String()).GetConnection(ctx)
				if !assert.Error(t, err, "Expected GetConnection to fail") {
					return
				}

				assert.Equal(t, ErrCodeProtocol, GetSystemErrorCode(err), "Unexpected error code, got error: %v", err)
				assert.Contains(t, err.Error(), tt.errMsg)
			}()
			// Never leave the subtest while the client goroutine may still
			// report through t. Registered before the connection's Close so it
			// runs after it.
			defer func() { <-done }()

			conn, err := ln.Accept()
			require.NoError(t, err, "Failed to accept connection")
			defer conn.Close()

			// Read the frame and verify that it's an initReq.
			f, err := readFrame(conn)
			require.NoError(t, err, "read frame failed")
			if !assert.Equal(t, messageTypeInitReq, f.messageType(), "Expected first message to be initReq") {
				return
			}

			// Write out the specified initRes, then wait for the channel to get an error.
			assert.NoError(t, writeMessage(conn, tt.msg), "write initRes failed")
			<-done
		})
	}
}

// TestHandleInitReqNewVersion sends an initReq with a protocol version newer
// than the current one and verifies that the channel replies with an initRes
// advertising the current protocol version and its own init params.
func TestHandleInitReqNewVersion(t *testing.T) {
	ch, err := NewChannel("test", nil)
	require.NoError(t, err)
	defer ch.Close()
	require.NoError(t, ch.ListenAndServe("127.0.0.1:0"))
	hostPort := ch.PeerInfo().HostPort

	conn, err := net.Dial("tcp", hostPort)
	require.NoError(t, err)
	defer conn.Close()
	conn.SetReadDeadline(time.Now().Add(time.Second))

	initMsg := &initReq{initMessage{id: 1, Version: CurrentProtocolVersion + 3, initParams: initParams{
		InitParamHostPort:    "0.0.0.0:0",
		InitParamProcessName: "test",
	}}}
	require.NoError(t, writeMessage(conn, initMsg), "write to conn failed")

	// Verify we get an initRes back with the current protocol version.
	f, err := readFrame(conn)
	require.NoError(t, err, "expected frame with init res")

	var msg initRes
	require.NoError(t, f.read(&msg), "could not read init res from frame")
	if assert.Equal(t, messageTypeInitRes, f.Header.messageType, "expected initRes, got %v", f.Header.messageType) {
		assert.Equal(t, initRes{
			initMessage: initMessage{
				Version: CurrentProtocolVersion,
				initParams: initParams{
					InitParamHostPort:                ch.PeerInfo().HostPort,
					InitParamProcessName:             ch.PeerInfo().ProcessName,
					InitParamTChannelLanguage:        "go",
					InitParamTChannelLanguageVersion: strings.TrimPrefix(runtime.Version(), "go"),
					InitParamTChannelVersion:         VersionInfo,
				},
			},
		}, msg, "unexpected init res")
	}
}

// TestHandleInitRes ensures that a Connection is ready to handle messages immediately
// after receiving an InitRes.
func TestHandleInitRes(t *testing.T) { l := newListener(t) listenerComplete := make(chan struct{}) go func() { conn, err := l.Accept() require.NoError(t, err, "l.Accept failed") // The connection should be kept open until the test has completed running. defer conn.Close() defer func() { listenerComplete <- struct{}{} }() f, err := readFrame(conn) require.NoError(t, err, "readFrame failed") assert.Equal(t, messageTypeInitReq, f.Header.messageType, "expected initReq message") var msg initReq require.NoError(t, f.read(&msg), "read frame into initMsg failed") initRes := initRes{msg.initMessage} initRes.initMessage.id = f.Header.ID require.NoError(t, writeMessage(conn, &initRes), "write initRes failed") require.NoError(t, writeMessage(conn, &pingReq{noBodyMsg{}, 10}), "write pingReq failed") f, err = readFrame(conn) require.NoError(t, err, "readFrame failed") assert.Equal(t, messageTypePingRes, f.Header.messageType, "expected pingRes message") }() ch, err := NewChannel("test-svc", nil) require.NoError(t, err, "NewChannel failed") defer ch.Close() ctx, cancel := NewContext(time.Second) defer cancel() _, err = ch.Peers().GetOrAdd(l.Addr().String()).GetConnection(ctx) require.NoError(t, err, "GetConnection failed") <-listenerComplete } func TestInitReqGetsError(t *testing.T) { l := newListener(t) listenerComplete := make(chan struct{}) connectionComplete := make(chan struct{}) go func() { defer func() { listenerComplete <- struct{}{} }() conn, err := l.Accept() require.NoError(t, err, "l.Accept failed") defer conn.Close() f, err := readFrame(conn) require.NoError(t, err, "readFrame failed") assert.Equal(t, messageTypeInitReq, f.Header.messageType, "expected initReq message") err = writeMessage(conn, &errorMessage{ id: f.Header.ID, errCode: ErrCodeBadRequest, message: "invalid host:port", }) assert.NoError(t, err, "Failed to write errorMessage") // Wait till GetConnection returns before closing the connection. 
<-connectionComplete }() logOut := &bytes.Buffer{} ch, err := NewChannel("test-svc", &ChannelOptions{Logger: NewLevelLogger(NewLogger(logOut), LogLevelWarn)}) require.NoError(t, err, "NewClient failed") defer ch.Close() ctx, cancel := NewContext(time.Second) defer cancel() _, err = ch.Peers().GetOrAdd(l.Addr().String()).GetConnection(ctx) expectedErr := NewSystemError(ErrCodeBadRequest, "invalid host:port") assert.Equal(t, expectedErr, err, "Error mismatch") assert.Contains(t, logOut.String(), "[E] Failed during connection handshake.", "Message should be logged") assert.Contains(t, logOut.String(), "tchannel error ErrCodeBadRequest: invalid host:port", "Error should be logged") close(connectionComplete) <-listenerComplete } func newListener(t *testing.T) net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "Listen failed") return l } ================================================ FILE: internal/argreader/empty.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package argreader import ( "fmt" "io" "sync" ) var _bufPool = sync.Pool{ New: func() interface{} { b := make([]byte, 128) return &b }, } // EnsureEmpty ensures that the specified reader is empty. If the reader is // not empty, it returns an error with the specified stage in the message. func EnsureEmpty(r io.Reader, stage string) error { buf := _bufPool.Get().(*[]byte) defer _bufPool.Put(buf) n, err := r.Read(*buf) if n > 0 { return fmt.Errorf("found unexpected bytes after %s, found (upto 128 bytes): %x", stage, (*buf)[:n]) } if err == io.EOF { return nil } return err } ================================================ FILE: internal/argreader/empty_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package argreader import ( "bytes" "testing" "github.com/uber/tchannel-go/testutils/testreader" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEnsureEmptySuccess(t *testing.T) { reader := bytes.NewReader(nil) err := EnsureEmpty(reader, "success") require.NoError(t, err, "ensureEmpty should succeed with empty reader") } func TestEnsureEmptyHasBytes(t *testing.T) { reader := bytes.NewReader([]byte{1, 2, 3}) err := EnsureEmpty(reader, "T") require.Error(t, err, "ensureEmpty should fail when there's bytes") assert.Equal(t, err.Error(), "found unexpected bytes after T, found (upto 128 bytes): 010203") } func TestEnsureEmptyError(t *testing.T) { control, reader := testreader.ChunkReader() control <- nil close(control) err := EnsureEmpty(reader, "has bytes") require.Error(t, err, "ensureEmpty should fail when there's an error") assert.Equal(t, testreader.ErrUser, err, "Unexpected error") } ================================================ FILE: internal/testcert/testcert.go ================================================ // Copyright (c) 2022 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testcert // TestCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. // generated from src/crypto/tls: // go run "$(go env GOROOT)/src/crypto/tls/generate_cert.go" --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var TestCert = []byte(`-----BEGIN CERTIFICATE----- MIIDOTCCAiGgAwIBAgIQD2X8uKDzMVRc0crgmNX/0zANBgkqhkiG9w0BAQsFADAS MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A MIIBCgKCAQEAqvPofDY9ItZCO7TWb/Symnb38SuuJt4o6iTNlsE0wFPfWdYlE760 PRW2rUqE7t0M2AQwHD3OWPpzLZcqZA2aSKEyx/GmQuNUYN87idYW1JhbxD3zn14P fflcf9s3PiWscnOM9xmPOkSvCptG9IdOs2l1TqmM91+z6AIS/M1yJvETcLJjZqTE v5YK8RuSdTk1prgKA25HLSnwn8JFkG3L9lc0y96W2gwcW5j3+RmVie+k57pa67LD aD2cMBDXcI+OFlDxecjtuaKJBZtbU/0QS0ehc9XXCgRvwUlg1T/MDb5Oi5z+rhuK CP2aLd7QvTYiSgw3J0f/g52QWdBzkBaZFQIDAQABo4GIMIGFMA4GA1UdDwEB/wQE AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud DgQWBBQqXSCk6h8ksO7U+3NH2nsM0GPkRjAuBgNVHREEJzAlggtleGFtcGxlLmNv bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAf4DP yoGZ26s5IkBK5iJBpIFtIWnejBSPc7gdFmQsFb9qjRt7kQf7bKLkER0FLFmq3I0f lsmWcYwvuLZSCQppxNB1lzcWqiE9LkHrO1wNJqcipPtOwhg9VYLgwi2BJd6mMr++ EHJntBgGpsvM4nqSanjjMlaE1ZPP2flt8/xSnikY78P7aYmHPL4xY5Al8zI09H1o pc96r62fgMPMSDibhF5tqz5nK7Olt2Jd/alHd7LMzVOQw2DfCaBrj8OPO2J4ppvu rqJ+Izqv7kZpwU1Ye6dFG/F8TOp1iWhkCoVR17FP6dqY1BZLfxiz3YsoS+2XVh3z CTWY1J1Aj1WiEVBTfg== 
-----END CERTIFICATE-----`) // TestKey is the private key for TestCert. var TestKey = []byte(`-----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCq8+h8Nj0i1kI7 tNZv9LKadvfxK64m3ijqJM2WwTTAU99Z1iUTvrQ9FbatSoTu3QzYBDAcPc5Y+nMt lypkDZpIoTLH8aZC41Rg3zuJ1hbUmFvEPfOfXg99+Vx/2zc+Jaxyc4z3GY86RK8K m0b0h06zaXVOqYz3X7PoAhL8zXIm8RNwsmNmpMS/lgrxG5J1OTWmuAoDbkctKfCf wkWQbcv2VzTL3pbaDBxbmPf5GZWJ76TnulrrssNoPZwwENdwj44WUPF5yO25ookF m1tT/RBLR6Fz1dcKBG/BSWDVP8wNvk6LnP6uG4oI/Zot3tC9NiJKDDcnR/+DnZBZ 0HOQFpkVAgMBAAECggEATDuyW9mwD53uMUPmMEy1bK5KyNBKu+hr5GX/DBAiXvXH 7v7Qz+pF48uQB9zoRMBsXtQXRDDHmOQugpEbhTyPpX3E8GaxVribQwupOEExMyKy IWPjBRlj3TBa8GUoUF1qditTHEnYlgpU6GzwClFgZh9MAYUYaKPTzU1HfFZ9ZiF2 jZB841HorsAJzbTnKXpHSK51GZ0ecOPGhRMkImsAskuI/EY5RBUZJmI9vVrs0pIu OO9TcAvSs9tNXfM8YrJwZVMG11qiCcvfHD3VuYhsYEOvCsjxSmRp4DCYlISTlUr+ LXv7VdhGMoeSdQVQqpqPF9kqkghfOzQFQ9ppzw6iDQKBgQDSmPNIY0f7nZH4diir A0WUl7QzzUyf2qX4UrYzgGHufEfanTlrS3sTAdEkK85oxfNygLBXYmxtrzcQWVFD gx5cXDHaH6ZVoZxSRrDyO37vrVv76NSrOH3yqq9j8gytf3M74dTcunMVOGGdx1Zi D/AQ05KpjdKmhBDyCdGcHvXAqwKBgQDPzu8YdP56w3VNkPAlXRLZu9g5eZHj4uPF NRexV8BdbQ8EVu3KnIjzCSUSjPdGDN18ycgTrU0AzQ8MxQE8rqebs/otPTKsYJt4 SwR/Ol+lDC+lGdSTREUu677MPE0buAce0UBQ9RtWoYUEsNEI6sFqReaCqmri55tm ioM4T3qNPwKBgQCQU8YXDANfC2PodYH1gW6EIVucTMyAmSY5guXfcdKr0Hyl9C5P vBECu7ILKgJxh4gKJuuzV36bxQLlr3Cj5g4+meiIZjxmXzV0pYHK4L9jntl1UOG+ 3h5i2lsNEetiVAAzP9fT1evc1SEBMoWe+vE5duYCUXHWMJg0aEpAxm8BtQKBgQCX BYBlecDnXt0E/exIexeT/RvqyRrpTp7RVwBc9bTrMLLVKIev04sDdQXoMWITGo5s fghVpIBtsJjbYuC/RP6x/V43Ol51P9A83+fovnd77xtBFUCTte3BZ7pFmx0+o8Mo 9lGThE3V65RMEGQZ4uGlZh9bnpYHSOJ65vbuGXSq6QKBgHthfDeAsW7V4JIm0IG+ sEkFjGvYhyngDbOKMSf9YN3YuuuLPawHQJYe7gmH4p/Wry+oUcF8t5ddhwLd63xz q4LAT9EgEvfLEbMnxjvLHUG/eeRx6zqCf54+KHfGCcooOI4kbI7lkQglLq5DWDe2 4n6AEKY0aVWJ1zN9B/vaJMZM -----END PRIVATE KEY-----`) ================================================ FILE: introspection.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"encoding/json"
	"fmt"
	"runtime"
	"sort"
	"strconv"
	"time"

	"golang.org/x/net/context"
)

// IntrospectionOptions are the options used when introspecting the Channel.
// The introspection endpoint decodes these from the JSON request body, so the
// JSON names below are the keys a caller sends.
type IntrospectionOptions struct {
	// IncludeExchanges will include all the IDs in the message exchanges.
	IncludeExchanges bool `json:"includeExchanges"`

	// IncludeEmptyPeers will include peers, even if they have no connections.
	IncludeEmptyPeers bool `json:"includeEmptyPeers"`

	// IncludeTombstones will include tombstones when introspecting relays.
	IncludeTombstones bool `json:"includeTombstones"`

	// IncludeOtherChannels will include basic information about other channels
	// created in the same process as this channel.
	IncludeOtherChannels bool `json:"includeOtherChannels"`
}

// RuntimeVersion includes version information about the runtime and
// the tchannel library.
type RuntimeVersion struct { GoVersion string `json:"goVersion"` LibraryVersion string `json:"tchannelVersion"` } // RuntimeState is a snapshot of the runtime state for a channel. type RuntimeState struct { ID uint32 `json:"id"` ChannelState string `json:"channelState"` // CreatedStack is the stack for how this channel was created. CreatedStack string `json:"createdStack"` // LocalPeer is the local peer information (service name, host-port, etc). LocalPeer LocalPeerInfo `json:"localPeer"` // SubChannels contains information about any subchannels. SubChannels map[string]SubChannelRuntimeState `json:"subChannels"` // RootPeers contains information about all the peers on this channel and their connections. RootPeers map[string]PeerRuntimeState `json:"rootPeers"` // Peers is the list of shared peers for this channel. Peers []SubPeerScore `json:"peers"` // NumConnections is the number of connections stored in the channel. NumConnections int `json:"numConnections"` // Connections is the list of connection IDs in the channel Connections []uint32 ` json:"connections"` // InactiveConnections is the connection state for connections that are not active, // and hence are not reported as part of root peers. InactiveConnections []ConnectionRuntimeState `json:"inactiveConnections"` // OtherChannels is information about any other channels running in this process. OtherChannels map[string][]ChannelInfo `json:"otherChannels,omitEmpty"` // RuntimeVersion is the version information about the runtime and the library. RuntimeVersion RuntimeVersion `json:"runtimeVersion"` } // GoRuntimeStateOptions are the options used when getting Go runtime state. type GoRuntimeStateOptions struct { // IncludeGoStacks will include all goroutine stacks. IncludeGoStacks bool `json:"includeGoStacks"` } // ChannelInfo is the state of other channels in the same process. 
type ChannelInfo struct { ID uint32 `json:"id"` CreatedStack string `json:"createdStack"` LocalPeer LocalPeerInfo `json:"localPeer"` } // GoRuntimeState is a snapshot of runtime stats from the runtime. type GoRuntimeState struct { MemStats runtime.MemStats `json:"memStats"` NumGoroutines int `json:"numGoRoutines"` NumCPU int `json:"numCPU"` NumCGo int64 `json:"numCGo"` GoStacks []byte `json:"goStacks,omitempty"` } // SubChannelRuntimeState is the runtime state for a subchannel. type SubChannelRuntimeState struct { Service string `json:"service"` Isolated bool `json:"isolated"` // IsolatedPeers is the list of all isolated peers for this channel. IsolatedPeers []SubPeerScore `json:"isolatedPeers,omitempty"` Handler HandlerRuntimeState `json:"handler"` } // HandlerRuntimeState TODO type HandlerRuntimeState struct { Type handlerType `json:"type"` Methods []string `json:"methods,omitempty"` } type handlerType string func (h handlerType) String() string { return string(h) } const ( methodHandler handlerType = "methods" overrideHandler = "overriden" ) // SubPeerScore show the runtime state of a peer with score. type SubPeerScore struct { HostPort string `json:"hostPort"` Score uint64 `json:"score"` } // ConnectionRuntimeState is the runtime state for a single connection. 
type ConnectionRuntimeState struct {
	ID               uint32                  `json:"id"`
	ConnectionState  string                  `json:"connectionState"`
	LocalHostPort    string                  `json:"localHostPort"`
	RemoteHostPort   string                  `json:"remoteHostPort"`
	OutboundHostPort string                  `json:"outboundHostPort"`
	RemotePeer       PeerInfo                `json:"remotePeer"`
	InboundExchange  ExchangeSetRuntimeState `json:"inboundExchange"`
	OutboundExchange ExchangeSetRuntimeState `json:"outboundExchange"`
	// Relayer is only populated for connections that have a relayer.
	Relayer           RelayerRuntimeState `json:"relayer"`
	HealthChecks      []bool              `json:"healthChecks,omitempty"`
	LastActivityRead  int64               `json:"lastActivityRead"`
	LastActivityWrite int64               `json:"lastActivityWrite"`
	// SendChQueued and SendChCapacity are the current length and the capacity
	// of the connection's send channel.
	SendChQueued    int `json:"sendChQueued"`
	SendChCapacity  int `json:"sendChCapacity"`
	SendBufferUsage int `json:"sendBufferUsage"`
	SendBufferSize  int `json:"sendBufferSize"`
}

// RelayerRuntimeState is the runtime state for a single relayer.
type RelayerRuntimeState struct {
	// Count is the combined number of inbound and outbound relay items.
	Count                int               `json:"count"`
	InboundItems         RelayItemSetState `json:"inboundItems"`
	OutboundItems        RelayItemSetState `json:"outboundItems"`
	MaxTimeout           time.Duration     `json:"maxTimeout"`
	MaxConnectionTimeout time.Duration     `json:"maxConnectionTimeout"`
}

// ExchangeSetRuntimeState is the runtime state for a message exchange set.
type ExchangeSetRuntimeState struct {
	Name  string `json:"name"`
	Count int    `json:"count"`
	// Exchanges is keyed by the exchange ID in decimal; it is only filled in
	// when exchanges are requested via IntrospectionOptions.IncludeExchanges.
	Exchanges map[string]ExchangeRuntimeState `json:"exchanges,omitempty"`
}

// RelayItemSetState is the runtime state for a list of relay items.
type RelayItemSetState struct {
	Name  string `json:"name"`
	Count int    `json:"count"`
	// Items is keyed by the relay item ID in decimal; it is only filled in
	// when exchanges are requested via IntrospectionOptions.IncludeExchanges.
	Items map[string]RelayItemState `json:"items,omitempty"`
}

// ExchangeRuntimeState is the runtime state for a single message exchange.
type ExchangeRuntimeState struct {
	ID          uint32      `json:"id"`
	MessageType messageType `json:"messageType"`
}

// RelayItemState is the runtime state for a single relay item.
type RelayItemState struct { ID uint32 `json:"id"` RemapID uint32 `json:"remapID"` DestinationConnectionID uint32 `json:"destinationConnectionID"` Tomb bool `json:"tomb"` } // PeerRuntimeState is the runtime state for a single peer. type PeerRuntimeState struct { HostPort string `json:"hostPort"` OutboundConnections []ConnectionRuntimeState `json:"outboundConnections"` InboundConnections []ConnectionRuntimeState `json:"inboundConnections"` ChosenCount uint64 `json:"chosenCount"` SCCount uint32 `json:"scCount"` } // IntrospectState returns the RuntimeState for this channel. // Note: this is purely for debugging and monitoring, and may slow down your Channel. func (ch *Channel) IntrospectState(opts *IntrospectionOptions) *RuntimeState { if opts == nil { opts = &IntrospectionOptions{} } ch.mutable.RLock() state := ch.mutable.state numConns := len(ch.mutable.conns) inactiveConns := make([]*Connection, 0, numConns) connIDs := make([]uint32, 0, numConns) for id, conn := range ch.mutable.conns { connIDs = append(connIDs, id) if !conn.IsActive() { inactiveConns = append(inactiveConns, conn) } } ch.mutable.RUnlock() ch.State() return &RuntimeState{ ID: ch.chID, ChannelState: state.String(), CreatedStack: ch.createdStack, LocalPeer: ch.PeerInfo(), SubChannels: ch.subChannels.IntrospectState(opts), RootPeers: ch.RootPeers().IntrospectState(opts), Peers: ch.Peers().IntrospectList(opts), NumConnections: numConns, Connections: connIDs, InactiveConnections: getConnectionRuntimeState(inactiveConns, opts), OtherChannels: ch.IntrospectOthers(opts), RuntimeVersion: introspectRuntimeVersion(), } } // IntrospectOthers returns the ChannelInfo for all other channels in this process. 
func (ch *Channel) IntrospectOthers(opts *IntrospectionOptions) map[string][]ChannelInfo { if !opts.IncludeOtherChannels { return nil } channelMap.Lock() defer channelMap.Unlock() states := make(map[string][]ChannelInfo) for svc, channels := range channelMap.existing { channelInfos := make([]ChannelInfo, 0, len(channels)) for _, otherChan := range channels { if ch == otherChan { continue } channelInfos = append(channelInfos, otherChan.ReportInfo(opts)) } states[svc] = channelInfos } return states } // ReportInfo returns ChannelInfo for a channel. func (ch *Channel) ReportInfo(opts *IntrospectionOptions) ChannelInfo { return ChannelInfo{ ID: ch.chID, CreatedStack: ch.createdStack, LocalPeer: ch.PeerInfo(), } } type containsPeerList interface { Copy() map[string]*Peer } func fromPeerList(peers containsPeerList, opts *IntrospectionOptions) map[string]PeerRuntimeState { m := make(map[string]PeerRuntimeState) for _, peer := range peers.Copy() { peerState := peer.IntrospectState(opts) if len(peerState.InboundConnections)+len(peerState.OutboundConnections) > 0 || opts.IncludeEmptyPeers { m[peer.HostPort()] = peerState } } return m } // IntrospectState returns the runtime state of the func (l *RootPeerList) IntrospectState(opts *IntrospectionOptions) map[string]PeerRuntimeState { return fromPeerList(l, opts) } // IntrospectState returns the runtime state of the subchannels. 
func (subChMap *subChannelMap) IntrospectState(opts *IntrospectionOptions) map[string]SubChannelRuntimeState { m := make(map[string]SubChannelRuntimeState) subChMap.RLock() for k, sc := range subChMap.subchannels { state := SubChannelRuntimeState{ Service: k, Isolated: sc.Isolated(), } if state.Isolated { state.IsolatedPeers = sc.Peers().IntrospectList(opts) } if hmap, ok := sc.handler.(*handlerMap); ok { state.Handler.Type = methodHandler methods := make([]string, 0, len(hmap.handlers)) for k := range hmap.handlers { methods = append(methods, k) } sort.Strings(methods) state.Handler.Methods = methods } else { state.Handler.Type = overrideHandler } m[k] = state } subChMap.RUnlock() return m } func getConnectionRuntimeState(conns []*Connection, opts *IntrospectionOptions) []ConnectionRuntimeState { connStates := make([]ConnectionRuntimeState, len(conns)) for i, conn := range conns { connStates[i] = conn.IntrospectState(opts) } return connStates } // IntrospectState returns the runtime state for this peer. func (p *Peer) IntrospectState(opts *IntrospectionOptions) PeerRuntimeState { p.RLock() defer p.RUnlock() return PeerRuntimeState{ HostPort: p.hostPort, InboundConnections: getConnectionRuntimeState(p.inboundConnections, opts), OutboundConnections: getConnectionRuntimeState(p.outboundConnections, opts), ChosenCount: p.chosenCount.Load(), SCCount: p.scCount, } } // IntrospectState returns the runtime state for this connection. func (c *Connection) IntrospectState(opts *IntrospectionOptions) ConnectionRuntimeState { c.stateMut.RLock() defer c.stateMut.RUnlock() // Ignore errors getting send buffer sizes. sendBufUsage, sendBufSize, _ := c.sendBufSize() // TODO(prashantv): Add total number of health checks, and health check options. 
state := ConnectionRuntimeState{ ID: c.connID, ConnectionState: c.state.String(), LocalHostPort: c.conn.LocalAddr().String(), RemoteHostPort: c.conn.RemoteAddr().String(), OutboundHostPort: c.outboundHP, RemotePeer: c.remotePeerInfo, InboundExchange: c.inbound.IntrospectState(opts), OutboundExchange: c.outbound.IntrospectState(opts), HealthChecks: c.healthCheckHistory.asBools(), LastActivityRead: c.lastActivityRead.Load(), LastActivityWrite: c.lastActivityWrite.Load(), SendChQueued: len(c.sendCh), SendChCapacity: cap(c.sendCh), SendBufferUsage: sendBufUsage, SendBufferSize: sendBufSize, } if c.relay != nil { state.Relayer = c.relay.IntrospectState(opts) } return state } // IntrospectState returns the runtime state for this relayer. func (r *Relayer) IntrospectState(opts *IntrospectionOptions) RelayerRuntimeState { count := r.inbound.Count() + r.outbound.Count() return RelayerRuntimeState{ Count: count, InboundItems: r.inbound.IntrospectState(opts, "inbound"), OutboundItems: r.outbound.IntrospectState(opts, "outbound"), MaxTimeout: r.maxTimeout, MaxConnectionTimeout: r.maxConnTimeout, } } // IntrospectState returns the runtime state for this relayItems. func (ri *relayItems) IntrospectState(opts *IntrospectionOptions, name string) RelayItemSetState { setState := RelayItemSetState{ Name: name, Count: ri.Count(), } if opts.IncludeExchanges { ri.RLock() defer ri.RUnlock() setState.Items = make(map[string]RelayItemState, len(ri.items)) for k, v := range ri.items { if !opts.IncludeTombstones && v.tomb { continue } state := RelayItemState{ ID: k, RemapID: v.remapID, DestinationConnectionID: v.destination.conn.connID, Tomb: v.tomb, } setState.Items[strconv.Itoa(int(k))] = state } } return setState } // IntrospectState returns the runtime state for this messsage exchange set. 
func (mexset *messageExchangeSet) IntrospectState(opts *IntrospectionOptions) ExchangeSetRuntimeState { mexset.RLock() setState := ExchangeSetRuntimeState{ Name: mexset.name, Count: len(mexset.exchanges), } if opts != nil && opts.IncludeExchanges { setState.Exchanges = make(map[string]ExchangeRuntimeState, len(mexset.exchanges)) for k, v := range mexset.exchanges { state := ExchangeRuntimeState{ ID: k, MessageType: v.msgType, } setState.Exchanges[strconv.Itoa(int(k))] = state } } mexset.RUnlock() return setState } func getStacks(all bool) []byte { var buf []byte for n := 4096; n < 10*1024*1024; n *= 2 { buf = make([]byte, n) stackLen := runtime.Stack(buf, all) if stackLen < n { return buf[:stackLen] } } // return the first 10MB of stacks if we have more than 10MB. return buf } func (ch *Channel) handleIntrospection(arg3 []byte) interface{} { var opts struct { IntrospectionOptions // (optional) ID of the channel to introspection. If unspecified, uses ch. ChannelID *uint32 `json:"id"` } json.Unmarshal(arg3, &opts) if opts.ChannelID != nil { id := *opts.ChannelID var ok bool ch, ok = findChannelByID(id) if !ok { return map[string]string{"error": fmt.Sprintf(`failed to find channel with "id": %v`, id)} } } return ch.IntrospectState(&opts.IntrospectionOptions) } // IntrospectList returns the list of peers (hostport, score) in this peer list. func (l *PeerList) IntrospectList(opts *IntrospectionOptions) []SubPeerScore { var peers []SubPeerScore l.RLock() for _, ps := range l.peerHeap.peerScores { peers = append(peers, SubPeerScore{ HostPort: ps.Peer.hostPort, Score: ps.score, }) } l.RUnlock() return peers } // IntrospectNumConnections returns the number of connections returns the number // of connections. Note: like other introspection APIs, this is not a stable API. 
func (ch *Channel) IntrospectNumConnections() int { ch.mutable.RLock() numConns := len(ch.mutable.conns) ch.mutable.RUnlock() return numConns } func handleInternalRuntime(arg3 []byte) interface{} { var opts GoRuntimeStateOptions json.Unmarshal(arg3, &opts) state := GoRuntimeState{ NumGoroutines: runtime.NumGoroutine(), NumCPU: runtime.NumCPU(), NumCGo: runtime.NumCgoCall(), } runtime.ReadMemStats(&state.MemStats) if opts.IncludeGoStacks { state.GoStacks = getStacks(true /* all */) } return state } func introspectRuntimeVersion() RuntimeVersion { return RuntimeVersion{ GoVersion: runtime.Version(), LibraryVersion: VersionInfo, } } // registerInternal registers the following internal handlers which return runtime state: // // _gometa_introspect: TChannel internal state. // _gometa_runtime: Golang runtime stats. func (ch *Channel) createInternalHandlers() *handlerMap { internalHandlers := &handlerMap{} endpoints := []struct { name string handler func([]byte) interface{} }{ {"_gometa_introspect", ch.handleIntrospection}, {"_gometa_runtime", handleInternalRuntime}, } for _, ep := range endpoints { // We need ep in our closure. ep := ep handler := func(ctx context.Context, call *InboundCall) { var arg2, arg3 []byte if err := NewArgReader(call.Arg2Reader()).Read(&arg2); err != nil { return } if err := NewArgReader(call.Arg3Reader()).Read(&arg3); err != nil { return } if err := NewArgWriter(call.Response().Arg2Writer()).Write(nil); err != nil { return } NewArgWriter(call.Response().Arg3Writer()).WriteJSON(ep.handler(arg3)) } h := HandlerFunc(handler) internalHandlers.Register(h, ep.name) // Register under the service name of channel as well (for backwards compatibility). ch.GetSubChannel(ch.PeerInfo().ServiceName).Register(h, ep.name) } return internalHandlers } ================================================ FILE: introspection_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel_test

import (
	"context"
	"math"
	"strconv"
	"testing"
	"time"

	. "github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/json"
	"github.com/uber/tchannel-go/testutils"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Purpose of this test is to ensure introspection doesn't cause any panics
// and we have coverage of the introspection code.
func TestIntrospection(t *testing.T) {
	opts := testutils.NewOpts().
		AddLogFilter("Couldn't find handler", 1). // call with service name fails
		NoRelay()                                 // "tchannel" service name is not forwarded.
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		client := testutils.NewClient(t, nil)
		defer client.Close()

		ctx, cancel := json.NewContext(time.Second)
		defer cancel()

		// Introspect via the "tchannel" service name with every option enabled.
		var resp map[string]interface{}
		peer := client.Peers().GetOrAdd(ts.HostPort())
		err := json.CallPeer(ctx, peer, "tchannel", "_gometa_introspect", map[string]interface{}{
			"includeExchanges":  true,
			"includeEmptyPeers": true,
			"includeTombstones": true,
		}, &resp)
		require.NoError(t, err, "Call _gometa_introspect failed")

		// The same endpoint is also reachable under the server's own service name.
		err = json.CallPeer(ctx, peer, ts.ServiceName(), "_gometa_introspect", nil /* arg */, &resp)
		require.NoError(t, err, "Call _gometa_introspect failed")

		// Try making the call on any other service name will fail.
		err = json.CallPeer(ctx, peer, "unknown-service", "_gometa_runtime", map[string]interface{}{
			"includeGoStacks": true,
		}, &resp)
		require.Error(t, err, "_gometa_introspect should only be registered under tchannel")
	})
}

// TestIntrospectByID checks that passing an "id" to _gometa_introspect selects
// the channel to introspect, and that an unknown ID produces an "error" entry
// in the response rather than a failed call.
func TestIntrospectByID(t *testing.T) {
	testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) {
		client := testutils.NewClient(t, nil)
		defer client.Close()

		ctx, cancel := json.NewContext(time.Second)
		defer cancel()

		clientID := client.IntrospectState(nil).ID

		// Ask the server to introspect the client channel by its ID.
		var resp map[string]interface{}
		peer := client.Peers().GetOrAdd(ts.HostPort())
		err := json.CallPeer(ctx, peer, ts.ServiceName(), "_gometa_introspect", map[string]interface{}{
			"id": clientID,
		}, &resp)
		require.NoError(t, err, "Call _gometa_introspect failed")

		// Verify that the response matches the channel ID we expected.
		assert.EqualValues(t, clientID, resp["id"], "unexpected response channel ID")

		// If we use an ID which doesn't exist, we get an error in the response.
		resp = nil
		err = json.CallPeer(ctx, peer, ts.ServiceName(), "_gometa_introspect", map[string]interface{}{
			"id": math.MaxUint32,
		}, &resp)
		require.NoError(t, err, "Call _gometa_introspect failed")
		assert.EqualValues(t, `failed to find channel with "id": `+strconv.Itoa(math.MaxUint32), resp["error"])
	})
}

// TestIntrospectClosedConn checks the connection counts reported by
// introspection while a connection is closing with a call still in flight, and
// after it has been removed.
func TestIntrospectClosedConn(t *testing.T) {
	// Disable the relay, since the relay does not maintain a 1:1 mapping between
	// incoming connections vs outgoing connections.
	opts := testutils.NewOpts().NoRelay()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// gotEcho is closed once the echo handler runs; the handler then waits
		// until blockEcho is closed.
		blockEcho := make(chan struct{})
		gotEcho := make(chan struct{})
		testutils.RegisterEcho(ts.Server(), func() {
			close(gotEcho)
			<-blockEcho
		})

		ctx, cancel := NewContext(time.Second)
		defer cancel()

		assert.Equal(t, 0, ts.Server().IntrospectNumConnections(), "Expected no connection on new server")

		// Make sure that a closed connection will reduce NumConnections.
		client := ts.NewClient(nil)
		require.NoError(t, client.Ping(ctx, ts.HostPort()), "Ping from new client failed")
		assert.Equal(t, 1, ts.Server().IntrospectNumConnections(), "Number of connections expected to increase")

		go testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())

		// The state will change to "closeStarted", but be blocked due to the blocked
		// echo call.
		<-gotEcho
		client.Close()

		introspected := client.IntrospectState(nil)
		assert.Len(t, introspected.Connections, 1, "Expected single connection due to blocked call")
		assert.Len(t, introspected.InactiveConnections, 1, "Expected inactive connection due to blocked call")

		// Unblock the echo handler and wait for the server to drop the connection.
		close(blockEcho)
		require.True(t, testutils.WaitFor(100*time.Millisecond, func() bool {
			return ts.Server().IntrospectNumConnections() == 0
		}), "Closed connection did not get removed, num connections is %v", ts.Server().IntrospectNumConnections())

		// Each new client adds exactly one more server connection.
		for i := 0; i < 10; i++ {
			client := ts.NewClient(nil)
			defer client.Close()

			require.NoError(t, client.Ping(ctx, ts.HostPort()), "Ping from new client failed")
			assert.Equal(t, 1, client.IntrospectNumConnections(), "Client should have single connection")
			assert.Equal(t, i+1, ts.Server().IntrospectNumConnections(), "Incorrect number of server connections")
		}
	})
}

// TestIntrospectionNotBlocked checks that _gometa_runtime can still be called
// on the "tchannel" service after a panicking handler has been set on the
// "tchannel" sub-channel.
func TestIntrospectionNotBlocked(t *testing.T) {
	testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) {
		subCh := ts.Server().GetSubChannel("tchannel")
		subCh.SetHandler(HandlerFunc(func(ctx context.Context, inbound *InboundCall) {
			panic("should not be called")
		}))

		// Ensure that tchannel is also relayed
		if ts.HasRelay() {
			ts.RelayHost().Add("tchannel", ts.Server().PeerInfo().HostPort)
		}

		ctx, cancel := NewContext(time.Second)
		defer cancel()

		client := ts.NewClient(nil)
		peer := client.Peers().GetOrAdd(ts.HostPort())

		// Ensure that SetHandler doesn't block introspection.
		var resp interface{}
		err := json.CallPeer(Wrap(ctx), peer, "tchannel", "_gometa_runtime", nil, &resp)
		require.NoError(t, err, "Call _gometa_runtime failed")
	})
}

================================================ FILE: json/call.go ================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package json import ( "fmt" "github.com/uber/tchannel-go" "golang.org/x/net/context" ) // ErrApplication is an application error which contains the object returned from the other side. type ErrApplication map[string]interface{} func (e ErrApplication) Error() string { return fmt.Sprintf("JSON call failed: %v", map[string]interface{}(e)) } // Client is used to make JSON calls to other services. type Client struct { ch *tchannel.Channel targetService string hostPort string } // ClientOptions are options used when creating a client. type ClientOptions struct { HostPort string } // NewClient returns a json.Client used to make outbound JSON calls. 
func NewClient(ch *tchannel.Channel, targetService string, opts *ClientOptions) *Client { client := &Client{ ch: ch, targetService: targetService, } if opts != nil && opts.HostPort != "" { client.hostPort = opts.HostPort } return client } func makeCall(call *tchannel.OutboundCall, headers, arg3In, respHeaders, arg3Out, errorOut interface{}) (bool, string, error) { if mapHeaders, ok := headers.(map[string]string); ok { headers = tchannel.InjectOutboundSpan(call.Response(), mapHeaders) } if err := tchannel.NewArgWriter(call.Arg2Writer()).WriteJSON(headers); err != nil { return false, "arg2 write failed", err } if err := tchannel.NewArgWriter(call.Arg3Writer()).WriteJSON(arg3In); err != nil { return false, "arg3 write failed", err } // Call Arg2Reader before checking application error. if err := tchannel.NewArgReader(call.Response().Arg2Reader()).ReadJSON(respHeaders); err != nil { return false, "arg2 read failed", err } // If this is an error response, read the response into a map and return a jsonCallErr. if call.Response().ApplicationError() { if err := tchannel.NewArgReader(call.Response().Arg3Reader()).ReadJSON(errorOut); err != nil { return false, "arg3 read error failed", err } return false, "", nil } if err := tchannel.NewArgReader(call.Response().Arg3Reader()).ReadJSON(arg3Out); err != nil { return false, "arg3 read failed", err } return true, "", nil } func (c *Client) startCall(ctx context.Context, method string, callOptions *tchannel.CallOptions) (*tchannel.OutboundCall, error) { if c.hostPort != "" { return c.ch.BeginCall(ctx, c.hostPort, c.targetService, method, callOptions) } return c.ch.GetSubChannel(c.targetService).BeginCall(ctx, method, callOptions) } // Call makes a JSON call, with retries. 
func (c *Client) Call(ctx Context, method string, arg, resp interface{}) error { var ( headers = ctx.Headers() respHeaders map[string]string respErr ErrApplication errAt string isOK bool ) err := c.ch.RunWithRetry(ctx, func(ctx context.Context, rs *tchannel.RequestState) error { respHeaders, respErr, isOK = nil, nil, false errAt = "connect" call, err := c.startCall(ctx, method, &tchannel.CallOptions{ Format: tchannel.JSON, RequestState: rs, }) if err != nil { return err } isOK, errAt, err = makeCall(call, headers, arg, &respHeaders, resp, &respErr) return err }) if err != nil { // TODO: Don't lose the error type here. return fmt.Errorf("%s: %v", errAt, err) } if !isOK { return respErr } return nil } // TODO(prashantv): Clean up json.Call* interfaces. func wrapCall(ctx Context, call *tchannel.OutboundCall, method string, arg, resp interface{}) error { var respHeaders map[string]string var respErr ErrApplication isOK, errAt, err := makeCall(call, ctx.Headers(), arg, &respHeaders, resp, &respErr) if err != nil { return fmt.Errorf("%s: %v", errAt, err) } if !isOK { return respErr } ctx.SetResponseHeaders(respHeaders) return nil } // CallPeer makes a JSON call using the given peer. func CallPeer(ctx Context, peer *tchannel.Peer, serviceName, method string, arg, resp interface{}) error { call, err := peer.BeginCall(ctx, serviceName, method, &tchannel.CallOptions{Format: tchannel.JSON}) if err != nil { return err } return wrapCall(ctx, call, method, arg, resp) } // CallSC makes a JSON call using the given subchannel. func CallSC(ctx Context, sc *tchannel.SubChannel, method string, arg, resp interface{}) error { call, err := sc.BeginCall(ctx, method, &tchannel.CallOptions{Format: tchannel.JSON}) if err != nil { return err } return wrapCall(ctx, call, method, arg, resp) } ================================================ FILE: json/context.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package json import ( "time" "github.com/uber/tchannel-go" "golang.org/x/net/context" ) // Context is a JSON Context which contains request and response headers. type Context tchannel.ContextWithHeaders // NewContext returns a Context that can be used to make JSON calls. func NewContext(timeout time.Duration) (Context, context.CancelFunc) { ctx, cancel := tchannel.NewContext(timeout) return tchannel.WrapWithHeaders(ctx, nil), cancel } // Wrap returns a JSON Context that wraps around a Context. func Wrap(ctx context.Context) Context { return tchannel.WrapWithHeaders(ctx, nil) } // WithHeaders returns a Context that can be used to make a call with request headers. 
func WithHeaders(ctx context.Context, headers map[string]string) Context { return tchannel.WrapWithHeaders(ctx, headers) } ================================================ FILE: json/handler.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package json import ( "fmt" "reflect" "github.com/uber/tchannel-go" "github.com/opentracing/opentracing-go" "golang.org/x/net/context" ) var ( typeOfError = reflect.TypeOf((*error)(nil)).Elem() typeOfContext = reflect.TypeOf((*Context)(nil)).Elem() ) // Handlers is the map from method names to handlers. 
type Handlers map[string]interface{} // verifyHandler ensures that the given t is a function with the following signature: // func(json.Context, *ArgType)(*ResType, error) func verifyHandler(t reflect.Type) error { if t.NumIn() != 2 || t.NumOut() != 2 { return fmt.Errorf("handler should be of format func(json.Context, *ArgType) (*ResType, error)") } isStructPtr := func(t reflect.Type) bool { return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct } isMap := func(t reflect.Type) bool { return t.Kind() == reflect.Map && t.Key().Kind() == reflect.String } validateArgRes := func(t reflect.Type, name string) error { if !isStructPtr(t) && !isMap(t) { return fmt.Errorf("%v should be a pointer to a struct, or a map[string]interface{}", name) } return nil } if t.In(0) != typeOfContext { return fmt.Errorf("arg0 should be of type json.Context") } if err := validateArgRes(t.In(1), "second argument"); err != nil { return err } if err := validateArgRes(t.Out(0), "first return value"); err != nil { return err } if !t.Out(1).AssignableTo(typeOfError) { return fmt.Errorf("second return value should be an error") } return nil } type handler struct { handler reflect.Value argType reflect.Type isArgMap bool tracer func() opentracing.Tracer } func toHandler(f interface{}) (*handler, error) { hV := reflect.ValueOf(f) if err := verifyHandler(hV.Type()); err != nil { return nil, err } argType := hV.Type().In(1) return &handler{handler: hV, argType: argType, isArgMap: argType.Kind() == reflect.Map}, nil } // Register registers the specified methods specified as a map from method name to the // JSON handler function. 
The handler functions should have the following signature: // func(context.Context, *ArgType)(*ResType, error) func Register(registrar tchannel.Registrar, funcs Handlers, onError func(context.Context, error)) error { handlers := make(map[string]*handler) handler := tchannel.HandlerFunc(func(ctx context.Context, call *tchannel.InboundCall) { h, ok := handlers[string(call.Method())] if !ok { onError(ctx, fmt.Errorf("call for unregistered method: %s", call.Method())) return } if err := h.Handle(ctx, call); err != nil { onError(ctx, err) } }) for m, f := range funcs { h, err := toHandler(f) if err != nil { return fmt.Errorf("%v cannot be used as a handler: %v", m, err) } h.tracer = func() opentracing.Tracer { return tchannel.TracerFromRegistrar(registrar) } handlers[m] = h registrar.Register(handler, m) } return nil } // Handle deserializes the JSON arguments and calls the underlying handler. func (h *handler) Handle(tctx context.Context, call *tchannel.InboundCall) error { var headers map[string]string if err := tchannel.NewArgReader(call.Arg2Reader()).ReadJSON(&headers); err != nil { return fmt.Errorf("arg2 read failed: %v", err) } tctx = tchannel.ExtractInboundSpan(tctx, call, headers, h.tracer()) ctx := WithHeaders(tctx, headers) var arg3 reflect.Value var callArg reflect.Value if h.isArgMap { arg3 = reflect.New(h.argType) // New returns a pointer, but the method accepts the map directly. callArg = arg3.Elem() } else { arg3 = reflect.New(h.argType.Elem()) callArg = arg3 } if err := tchannel.NewArgReader(call.Arg3Reader()).ReadJSON(arg3.Interface()); err != nil { return fmt.Errorf("arg3 read failed: %v", err) } args := []reflect.Value{reflect.ValueOf(ctx), callArg} results := h.handler.Call(args) res := results[0].Interface() err := results[1].Interface() // If an error was returned, we create an error arg3 to respond with. if err != nil { // TODO(prashantv): More consistent error handling between json/raw/thrift.. 
if serr, ok := err.(tchannel.SystemError); ok { return call.Response().SendSystemError(serr) } call.Response().SetApplicationError() // TODO(prashant): Allow client to customize the error in more ways. res = struct { Type string `json:"type"` Message string `json:"message"` }{ Type: "error", Message: err.(error).Error(), } } if err := tchannel.NewArgWriter(call.Response().Arg2Writer()).WriteJSON(ctx.ResponseHeaders()); err != nil { return err } return tchannel.NewArgWriter(call.Response().Arg3Writer()).WriteJSON(res) } ================================================ FILE: json/json_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package json

import (
	"fmt"
	"testing"
	"time"

	"github.com/uber/tchannel-go"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/context"
)

// ForwardArgs are the arguments specifying who to forward to (and the message to forward).
type ForwardArgs struct {
	// HeaderVal, when non-empty, replaces the "hdr" request header on the forwarded call.
	HeaderVal string
	// Service and Method identify the call to make.
	Service string
	Method  string
	// NextForward is sent as the argument when Method is "forward".
	NextForward *ForwardArgs
}

// Res is the final result.
type Res struct {
	Result string
}

// testHandler records the calls it receives and forwards them through peer.
type testHandler struct {
	// calls holds "<method>-<hdr header value>" for each call, in order.
	calls []string
	// callers holds the caller name of each call, in order.
	callers []string
	// peer is used for the outgoing calls made by forward.
	peer *tchannel.Peer
	t    *testing.T
}

// forward records the call, sets an "hdr" response header derived from the
// request header, and then calls args.Method on args.Service through the
// handler's peer, returning that call's result.
func (h *testHandler) forward(ctx Context, args *ForwardArgs) (*Res, error) {
	headerVal := ctx.Headers()["hdr"]
	ctx.SetResponseHeaders(map[string]string{"hdr": headerVal + "-resp"})
	h.calls = append(h.calls, "forward-"+headerVal)
	h.callers = append(h.callers, tchannel.CurrentCall(ctx).CallerName())

	// Only replace the outgoing header when a value was supplied.
	if args.HeaderVal != "" {
		ctx = WithHeaders(ctx, map[string]string{"hdr": args.HeaderVal})
	}
	res := &Res{}

	if args.Method == "forward" {
		if err := CallPeer(ctx, h.peer, args.Service, args.Method, args.NextForward, res); err != nil {
			h.t.Errorf("forward->forward Call failed: %v", err)
			return nil, err
		}
		assert.Equal(h.t, map[string]string{"hdr": args.HeaderVal + "-resp"}, ctx.ResponseHeaders())
		return res, nil
	}

	if err := CallPeer(ctx, h.peer, args.Service, args.Method, nil, res); err != nil {
		h.t.Errorf("forward->%v Call failed: %v", args.Method, err)
		return nil, err
	}

	return res, nil
}

// leaf records the call and returns a fixed result.
func (h *testHandler) leaf(ctx Context, _ *struct{}) (*Res, error) {
	headerVal := ctx.Headers()["hdr"]
	h.calls = append(h.calls, "leaf-"+headerVal)
	h.callers = append(h.callers, tchannel.CurrentCall(ctx).CallerName())
	return &Res{"leaf called!"}, nil
}

// onError fails the test for any error reported while handling a call.
func (h *testHandler) onError(ctx context.Context, err error) {
	h.t.Errorf("onError(%v)", err)
}

// TestForwardChain sends a chain of forwarded calls around three servers and
// checks the calls and caller names each server recorded.
func TestForwardChain(t *testing.T) {
	servers := map[string]*struct {
		channel   *tchannel.Channel
		handler   *testHandler
		otherPeer string
	}{
		"serv1": {otherPeer: "serv2"},
		"serv2": {otherPeer: "serv3"},
		"serv3": {otherPeer: "serv1"},
	}

	// We want the following call graph:
	// serv1.forward
	// -> (1) serv2.forward
	// -> (2) serv3.forward
	// -> (3) serv1.forward
	// -> (4) serv2.forward
	// ....
	// -> (11) serv3.leaf
	rootArg := &ForwardArgs{}
	curArg := rootArg
	for i := 1; i <= 10; i++ {
		service := fmt.Sprintf("serv%v", (i%3)+1)

		curArg.Method = "forward"
		curArg.HeaderVal = fmt.Sprint(i)
		curArg.Service = service
		curArg.NextForward = &ForwardArgs{}

		curArg = curArg.NextForward
	}
	curArg.Service = "serv3"
	curArg.HeaderVal = "11"
	curArg.Method = "leaf"

	expectedCalls := map[string]struct {
		calls   []string
		callers []string
	}{
		"serv1": {
			calls:   []string{"forward-initial", "forward-3", "forward-6", "forward-9"},
			callers: []string{"serv3", "serv3", "serv3", "serv3"},
		},
		"serv2": {
			calls:   []string{"forward-1", "forward-4", "forward-7", "forward-10"},
			callers: []string{"serv1", "serv1", "serv1", "serv1"},
		},
		"serv3": {
			calls:   []string{"forward-2", "forward-5", "forward-8", "leaf-11"},
			callers: []string{"serv2", "serv2", "serv2", "serv2"},
		},
	}

	// Use the above data to setup the test and ensure the calls are made as expected.
	for name, s := range servers {
		var err error
		s.channel, err = tchannel.NewChannel(name, nil)
		require.NoError(t, err)

		s.handler = &testHandler{t: t}
		require.NoError(t, Register(s.channel, Handlers{
			"forward": s.handler.forward,
			"leaf":    s.handler.leaf,
		}, s.handler.onError))

		require.NoError(t, s.channel.ListenAndServe("127.0.0.1:0"))
	}
	// Peers can only be wired up once every server is listening.
	for _, s := range servers {
		s.handler.peer = s.channel.Peers().Add(servers[s.otherPeer].channel.PeerInfo().HostPort)
	}

	ctx, cancel := NewContext(time.Second)
	defer cancel()
	ctx = WithHeaders(ctx, map[string]string{"hdr": "initial"})
	assert.Nil(t, tchannel.CurrentCall(ctx))

	// The initial call to serv1 is made from serv3's channel.
	sc := servers["serv3"].channel.GetSubChannel("serv1")
	resp := &Res{}
	if assert.NoError(t, CallSC(ctx, sc, "forward", rootArg, resp)) {
		assert.Equal(t, "leaf called!", resp.Result)
		for s, expected := range expectedCalls {
			assert.Equal(t, expected.calls, servers[s].handler.calls, "wrong calls for %v", s)
			assert.Equal(t, expected.callers, servers[s].handler.callers, "wrong callers for %v", s)
		}
	}
}

// TestHeadersForwarded checks that the incoming request header is passed on to
// the forwarded call when the forward arguments carry no header value.
func TestHeadersForwarded(t *testing.T) {
	ch, err := tchannel.NewChannel("svc", nil)
	require.NoError(t, err)

	handler := &testHandler{t: t}
	require.NoError(t, Register(ch, Handlers{
		"forward": handler.forward,
		"leaf":    handler.leaf,
	}, handler.onError))
	assert.NoError(t, ch.ListenAndServe("127.0.0.1:0"))

	rootArg := &ForwardArgs{
		Service:   "svc",
		Method:    "leaf",
		HeaderVal: "",
	}

	ctx, cancel := NewContext(time.Second)
	defer cancel()
	ctx = WithHeaders(ctx, map[string]string{"hdr": "copy"})

	assert.Nil(t, tchannel.CurrentCall(ctx))
	resp := &Res{}
	// The channel calls itself.
	handler.peer = ch.Peers().Add(ch.PeerInfo().HostPort)
	if assert.NoError(t, CallPeer(ctx, handler.peer, "svc", "forward", rootArg, resp)) {
		// Verify that the header is copied when ctx is not changed.
		assert.Equal(t, handler.calls, []string{"forward-copy", "leaf-copy"})
	}
}

// TestEmptyRequestHeader checks that a request written with an empty arg2
// reaches the handler with nil headers.
func TestEmptyRequestHeader(t *testing.T) {
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	ch, err := tchannel.NewChannel("server", nil)
	require.NoError(t, err)
	require.NoError(t, ch.ListenAndServe("127.0.0.1:0"))

	handler := func(ctx Context, _ *struct{}) (*struct{}, error) {
		assert.Equal(t, map[string]string(nil), ctx.Headers())
		return nil, nil
	}
	onError := func(ctx context.Context, err error) {
		t.Errorf("onError: %v", err)
	}
	require.NoError(t, Register(ch, Handlers{"handle": handler}, onError))

	call, err := ch.BeginCall(ctx, ch.PeerInfo().HostPort, "server", "handle", &tchannel.CallOptions{
		Format: tchannel.JSON,
	})
	require.NoError(t, err)

	// Write the arguments by hand so that arg2 is empty.
	require.NoError(t, tchannel.NewArgWriter(call.Arg2Writer()).Write(nil))
	require.NoError(t, tchannel.NewArgWriter(call.Arg3Writer()).WriteJSON(nil))

	resp := call.Response()
	var data interface{}
	require.NoError(t, tchannel.NewArgReader(resp.Arg2Reader()).ReadJSON(&data))
	require.NoError(t, tchannel.NewArgReader(resp.Arg3Reader()).ReadJSON(&data))
}

// TestMapInputOutput checks that a handler taking and returning a map echoes
// the request argument back unchanged.
func TestMapInputOutput(t *testing.T) {
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	ch, err := tchannel.NewChannel("server", nil)
	require.NoError(t, err)
	require.NoError(t, ch.ListenAndServe("127.0.0.1:0"))

	handler := func(ctx Context, args map[string]interface{}) (map[string]interface{}, error) {
		return args, nil
	}
	onError := func(ctx context.Context, err error) {
		t.Errorf("onError: %v", err)
	}
	require.NoError(t, Register(ch, Handlers{"handle": handler}, onError))

	call, err := ch.BeginCall(ctx, ch.PeerInfo().HostPort, "server", "handle", &tchannel.CallOptions{
		Format: tchannel.JSON,
	})
	require.NoError(t, err)

	arg := map[string]interface{}{
		"v1": "value1",
		"v2": 2.0,
		"v3": map[string]interface{}{"k": "v", "k2": "v2"},
	}
	require.NoError(t, tchannel.NewArgWriter(call.Arg2Writer()).Write(nil))
	require.NoError(t, tchannel.NewArgWriter(call.Arg3Writer()).WriteJSON(arg))

	resp := call.Response()
	var data interface{}
	require.NoError(t, tchannel.NewArgReader(resp.Arg2Reader()).ReadJSON(&data))
	require.NoError(t, tchannel.NewArgReader(resp.Arg3Reader()).ReadJSON(&data))
	assert.Equal(t, arg, data.(map[string]interface{}), "result does not match arg")
}

================================================ FILE: json/retry_test.go ================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package json import ( "strings" "testing" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRetryJSONCall(t *testing.T) { ch := testutils.NewServer(t, nil) ch.Peers().Add(ch.PeerInfo().HostPort) count := 0 handler := func(ctx Context, req map[string]string) (map[string]string, error) { count++ if count > 4 { return req, nil } return nil, tchannel.ErrServerBusy } Register(ch, Handlers{"test": handler}, nil) ctx, cancel := NewContext(time.Second) defer cancel() client := NewClient(ch, ch.ServiceName(), nil) var res map[string]string err := client.Call(ctx, "test", nil, &res) assert.NoError(t, err, "Call should succeed") assert.Equal(t, 5, count, "Handler should have been invoked 5 times") } func TestRetryJSONNoConnect(t *testing.T) { ch := testutils.NewClient(t, nil) ch.Peers().Add("0.0.0.0:0") ctx, cancel := NewContext(time.Second) defer cancel() var res map[string]interface{} client := NewClient(ch, ch.ServiceName(), nil) err := client.Call(ctx, "test", nil, &res) require.Error(t, err, "Call should fail") assert.True(t, strings.HasPrefix(err.Error(), "connect: "), "Error does not contain expected prefix: %v", err.Error()) } ================================================ FILE: json/tracing_test.go ================================================ package json_test import ( "testing" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/json" . 
"github.com/uber/tchannel-go/testutils/testtracing" "golang.org/x/net/context" ) // JSONHandler tests tracing over JSON encoding type JSONHandler struct { TraceHandler t *testing.T } func (h *JSONHandler) firstCall(ctx context.Context, req *TracingRequest) (*TracingResponse, error) { jctx := json.Wrap(ctx) response := new(TracingResponse) peer := h.Ch.Peers().GetOrAdd(h.Ch.PeerInfo().HostPort) if err := json.CallPeer(jctx, peer, h.Ch.PeerInfo().ServiceName, "call", req, response); err != nil { return nil, err } return response, nil } func (h *JSONHandler) callJSON(ctx json.Context, req *TracingRequest) (*TracingResponse, error) { return h.HandleCall(ctx, req, func(ctx context.Context, req *TracingRequest) (*TracingResponse, error) { jctx := ctx.(json.Context) peer := h.Ch.Peers().GetOrAdd(h.Ch.PeerInfo().HostPort) childResp := new(TracingResponse) if err := json.CallPeer(jctx, peer, h.Ch.PeerInfo().ServiceName, "call", req, childResp); err != nil { return nil, err } return childResp, nil }) } func (h *JSONHandler) onError(ctx context.Context, err error) { h.t.Errorf("onError %v", err) } func TestJSONTracingPropagation(t *testing.T) { suite := &PropagationTestSuite{ Encoding: EncodingInfo{Format: tchannel.JSON, HeadersSupported: true}, Register: func(t *testing.T, ch *tchannel.Channel) TracingCall { handler := &JSONHandler{TraceHandler: TraceHandler{Ch: ch}, t: t} json.Register(ch, json.Handlers{"call": handler.callJSON}, handler.onError) return handler.firstCall }, TestCases: map[TracerType][]PropagationTestCase{ Noop: { {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: "", ExpectedSpanCount: 0}, {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: "", ExpectedSpanCount: 0}, }, Mock: { {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: BaggageValue, ExpectedSpanCount: 0}, {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: BaggageValue, ExpectedSpanCount: 6}, }, Jaeger: { {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: 
BaggageValue, ExpectedSpanCount: 0}, {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: BaggageValue, ExpectedSpanCount: 6}, }, }, } suite.Run(t) } ================================================ FILE: largereq_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "bytes" "log" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/require" ) func TestLargeRequest(t *testing.T) { CheckStress(t) const ( KB = 1024 MB = 1024 * KB GB = 1024 * MB maxRequestSize = 1 * GB ) WithVerifiedServer(t, nil, func(serverCh *Channel, hostPort string) { serverCh.Register(raw.Wrap(newTestHandler(t)), "echo") for reqSize := 2; reqSize <= maxRequestSize; reqSize *= 2 { log.Printf("reqSize = %v", reqSize) arg3 := testutils.RandBytes(reqSize) arg2 := testutils.RandBytes(reqSize / 2) clientCh := testutils.NewClient(t, nil) ctx, cancel := NewContext(time.Second * 30) rArg2, rArg3, _, err := raw.Call(ctx, clientCh, hostPort, serverCh.PeerInfo().ServiceName, "echo", arg2, arg3) require.NoError(t, err, "Call failed") if !bytes.Equal(arg2, rArg2) { t.Errorf("echo arg2 mismatch") } if !bytes.Equal(arg3, rArg3) { t.Errorf("echo arg3 mismatch") } cancel() } }) } ================================================ FILE: localip.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "errors" "net" ) // scoreAddr scores how likely the given addr is to be a remote address and returns the // IP to use when listening. Any address which receives a negative score should not be used. // Scores are calculated as: // -1 for any unknown IP addreseses. // +300 for IPv4 addresses // +100 for non-local addresses, extra +100 for "up" interaces. func scoreAddr(iface net.Interface, addr net.Addr) (int, net.IP) { var ip net.IP if netAddr, ok := addr.(*net.IPNet); ok { ip = netAddr.IP } else if netIP, ok := addr.(*net.IPAddr); ok { ip = netIP.IP } else { return -1, nil } var score int if ip.To4() != nil { score += 300 } if iface.Flags&net.FlagLoopback == 0 && !ip.IsLoopback() { score += 100 if iface.Flags&net.FlagUp != 0 { score += 100 } } if isLocalMacAddr(iface.HardwareAddr) { score -= 50 } return score, ip } func listenIP(interfaces []net.Interface) (net.IP, error) { bestScore := -1 var bestIP net.IP // Select the highest scoring IP as the best IP. for _, iface := range interfaces { addrs, err := iface.Addrs() if err != nil { // Skip this interface if there is an error. continue } for _, addr := range addrs { score, ip := scoreAddr(iface, addr) if score > bestScore { bestScore = score bestIP = ip } } } if bestScore == -1 { return nil, errors.New("no addresses to listen on") } return bestIP, nil } // ListenIP returns the IP to bind to in Listen. It tries to find an IP that can be used // by other machines to reach this machine. 
func ListenIP() (net.IP, error) { interfaces, err := net.Interfaces() if err != nil { return nil, err } return listenIP(interfaces) } func mustParseMAC(s string) net.HardwareAddr { addr, err := net.ParseMAC(s) if err != nil { panic(err) } return addr } // If the first octet's second least-significant-bit is set, then it's local. // https://en.wikipedia.org/wiki/MAC_address#Universal_vs._local func isLocalMacAddr(addr net.HardwareAddr) bool { if len(addr) == 0 { return false } return addr[0]&2 == 2 } ================================================ FILE: localip_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestScoreAddr is a table-driven check of scoreAddr: each case supplies an
// interface and an address and states the expected score and returned IP.
func TestScoreAddr(t *testing.T) {
	ipv4 := net.ParseIP("10.0.1.2")
	ipv6 := net.ParseIP("2001:db8:a0b:12f0::1")

	tests := []struct {
		msg    string
		iface  net.Interface
		addr   net.Addr
		want   int
		wantIP net.IP
	}{
		{
			msg:    "non-local up ipv4 IPNet address",
			iface:  net.Interface{Flags: net.FlagUp},
			addr:   &net.IPNet{IP: ipv4},
			want:   500,
			wantIP: ipv4,
		},
		{
			msg:    "non-local up ipv4 IPAddr address",
			iface:  net.Interface{Flags: net.FlagUp},
			addr:   &net.IPAddr{IP: ipv4},
			want:   500,
			wantIP: ipv4,
		},
		{
			// The 0x02 first octet has the "locally administered" bit set.
			msg: "non-local up ipv4 IPAddr address, docker interface",
			iface: net.Interface{
				Flags:        net.FlagUp,
				HardwareAddr: mustParseMAC("02:42:ac:11:56:af"),
			},
			addr:   &net.IPNet{IP: ipv4},
			want:   450,
			wantIP: ipv4,
		},
		{
			msg: "non-local up ipv4 address, local MAC address",
			iface: net.Interface{
				Flags:        net.FlagUp,
				HardwareAddr: mustParseMAC("02:42:9c:52:fc:86"),
			},
			addr:   &net.IPNet{IP: ipv4},
			want:   450,
			wantIP: ipv4,
		},
		{
			msg:    "non-local down ipv4 address",
			iface:  net.Interface{},
			addr:   &net.IPNet{IP: ipv4},
			want:   400,
			wantIP: ipv4,
		},
		{
			msg:    "non-local down ipv6 address",
			iface:  net.Interface{},
			addr:   &net.IPAddr{IP: ipv6},
			want:   100,
			wantIP: ipv6,
		},
		{
			// Address types other than IPNet/IPAddr score -1 with a nil IP.
			msg:   "unknown address type",
			iface: net.Interface{},
			addr:  &net.UnixAddr{Name: "/tmp/socket"},
			want:  -1,
		},
	}

	for _, tt := range tests {
		gotScore, gotIP := scoreAddr(tt.iface, tt.addr)
		assert.Equal(t, tt.want, gotScore, tt.msg)
		assert.Equal(t, tt.wantIP, gotIP, tt.msg)
	}
}

================================================
FILE: logger.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "fmt" "io" "time" ) import ( "os" ) // Logger provides an abstract interface for logging from TChannel. // Applications can provide their own implementation of this interface to adapt // TChannel logging to whatever logging library they prefer (stdlib log, // logrus, go-logging, etc). The SimpleLogger adapts to the standard go log // package. type Logger interface { // Enabled returns whether the given level is enabled. Enabled(level LogLevel) bool // Fatal logs a message, then exits with os.Exit(1). Fatal(msg string) // Error logs a message at error priority. Error(msg string) // Warn logs a message at warning priority. Warn(msg string) // Infof logs a message at info priority. Infof(msg string, args ...interface{}) // Info logs a message at info priority. Info(msg string) // Debugf logs a message at debug priority. 
Debugf(msg string, args ...interface{}) // Debug logs a message at debug priority. Debug(msg string) // Fields returns the fields that this logger contains. Fields() LogFields // WithFields returns a logger with the current logger's fields and fields. WithFields(fields ...LogField) Logger } // LogField is a single field of additional information passed to the logger. type LogField struct { Key string Value interface{} } // ErrField wraps an error string as a LogField named "error" func ErrField(err error) LogField { return LogField{"error", err.Error()} } // LogFields is a list of LogFields used to pass additional information to the logger. type LogFields []LogField // NullLogger is a logger that emits nowhere var NullLogger Logger = nullLogger{} type nullLogger struct { fields LogFields } func (nullLogger) Enabled(_ LogLevel) bool { return false } func (nullLogger) Fatal(msg string) { os.Exit(1) } func (nullLogger) Error(msg string) {} func (nullLogger) Warn(msg string) {} func (nullLogger) Infof(msg string, args ...interface{}) {} func (nullLogger) Info(msg string) {} func (nullLogger) Debugf(msg string, args ...interface{}) {} func (nullLogger) Debug(msg string) {} func (l nullLogger) Fields() LogFields { return l.fields } func (l nullLogger) WithFields(fields ...LogField) Logger { newFields := make([]LogField, len(l.Fields())+len(fields)) n := copy(newFields, l.Fields()) copy(newFields[n:], fields) return nullLogger{newFields} } // SimpleLogger prints logging information to standard out. var SimpleLogger = NewLogger(os.Stdout) type writerLogger struct { writer io.Writer fields LogFields } const writerLoggerStamp = "15:04:05.000000" // NewLogger returns a Logger that writes to the given writer. 
func NewLogger(writer io.Writer, fields ...LogField) Logger { return &writerLogger{writer, fields} } func (l writerLogger) Fatal(msg string) { l.printfn("F", msg) os.Exit(1) } func (l writerLogger) Enabled(_ LogLevel) bool { return true } func (l writerLogger) Error(msg string) { l.printfn("E", msg) } func (l writerLogger) Warn(msg string) { l.printfn("W", msg) } func (l writerLogger) Infof(msg string, args ...interface{}) { l.printfn("I", msg, args...) } func (l writerLogger) Info(msg string) { l.printfn("I", msg) } func (l writerLogger) Debugf(msg string, args ...interface{}) { l.printfn("D", msg, args...) } func (l writerLogger) Debug(msg string) { l.printfn("D", msg) } func (l writerLogger) printfn(prefix, msg string, args ...interface{}) { fmt.Fprintf(l.writer, "%s [%s] %s tags: %v\n", time.Now().Format(writerLoggerStamp), prefix, fmt.Sprintf(msg, args...), l.fields) } func (l writerLogger) Fields() LogFields { return l.fields } func (l writerLogger) WithFields(newFields ...LogField) Logger { existingFields := l.Fields() fields := make(LogFields, 0, len(existingFields)+1) fields = append(fields, existingFields...) fields = append(fields, newFields...) return writerLogger{l.writer, fields} } // LogLevel is the level of logging used by LevelLogger. type LogLevel int // The minimum level that will be logged. e.g. LogLevelError only logs errors and fatals. const ( LogLevelAll LogLevel = iota LogLevelDebug LogLevelInfo LogLevelWarn LogLevelError LogLevelFatal ) type levelLogger struct { logger Logger level LogLevel } // NewLevelLogger returns a logger that only logs messages with a minimum of level. 
func NewLevelLogger(logger Logger, level LogLevel) Logger { return levelLogger{logger, level} } func (l levelLogger) Enabled(level LogLevel) bool { return l.level <= level } func (l levelLogger) Fatal(msg string) { if l.level <= LogLevelFatal { l.logger.Fatal(msg) } } func (l levelLogger) Error(msg string) { if l.level <= LogLevelError { l.logger.Error(msg) } } func (l levelLogger) Warn(msg string) { if l.level <= LogLevelWarn { l.logger.Warn(msg) } } func (l levelLogger) Infof(msg string, args ...interface{}) { if l.level <= LogLevelInfo { l.logger.Infof(msg, args...) } } func (l levelLogger) Info(msg string) { if l.level <= LogLevelInfo { l.logger.Info(msg) } } func (l levelLogger) Debugf(msg string, args ...interface{}) { if l.level <= LogLevelDebug { l.logger.Debugf(msg, args...) } } func (l levelLogger) Debug(msg string) { if l.level <= LogLevelDebug { l.logger.Debug(msg) } } func (l levelLogger) Fields() LogFields { return l.logger.Fields() } func (l levelLogger) WithFields(fields ...LogField) Logger { return levelLogger{ logger: l.logger.WithFields(fields...), level: l.level, } } ================================================ FILE: logger_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel_test

import (
	"bytes"
	"errors"
	"testing"

	. "github.com/uber/tchannel-go"

	"github.com/stretchr/testify/assert"
)

// field is shorthand for building a LogField in the tests below.
func field(k string, v interface{}) LogField {
	return LogField{Key: k, Value: v}
}

// TestErrField checks that ErrField stores the error's string under the
// "error" key.
func TestErrField(t *testing.T) {
	assert.Equal(t, field("error", "foo"), ErrField(errors.New("foo")))
}

// TestWriterLogger exercises the formatting methods (Debugf, Infof) of the
// writer-backed logger: the formatted message and level prefix must appear in
// the output, and each logger derived via WithFields must print only its own
// fields.
func TestWriterLogger(t *testing.T) {
	var buf bytes.Buffer
	var bufLogger = NewLogger(&buf)

	debugf := func(logger Logger, msg string, args ...interface{}) { logger.Debugf(msg, args...) }
	infof := func(logger Logger, msg string, args ...interface{}) { logger.Infof(msg, args...) }

	levels := []struct {
		levelFunc   func(logger Logger, msg string, args ...interface{})
		levelPrefix string
	}{
		{debugf, "D"},
		{infof, "I"},
	}

	for _, level := range levels {
		// Both tagged loggers derive from bufLogger, so neither should see
		// the other's fields.
		tagLogger1 := bufLogger.WithFields(field("key1", "value1"))
		tagLogger2 := bufLogger.WithFields(field("key2", "value2"), field("key3", "value3"))

		// verifyMsgAndPrefix resets the shared buffer, logs once and checks
		// the message and prefix; callers inspect buf afterwards for fields.
		verifyMsgAndPrefix := func(logger Logger) {
			buf.Reset()
			level.levelFunc(logger, "mes%v", "sage")

			out := buf.String()
			assert.Contains(t, out, "message")
			assert.Contains(t, out, "["+level.levelPrefix+"]")
		}

		verifyMsgAndPrefix(bufLogger)

		verifyMsgAndPrefix(tagLogger1)
		assert.Contains(t, buf.String(), "{key1 value1}")
		assert.NotContains(t, buf.String(), "{key2 value2}")
		assert.NotContains(t, buf.String(), "{key3 value3}")

		verifyMsgAndPrefix(tagLogger2)
		assert.Contains(t, buf.String(), "{key2 value2}")
		assert.Contains(t, buf.String(), "{key3 value3}")
		assert.NotContains(t, buf.String(), "{key1 value1}")
	}
}

// TestWriterLoggerNoSubstitution is the counterpart of TestWriterLogger for
// the non-formatting methods (Debug, Info, Warn, Error).
func TestWriterLoggerNoSubstitution(t *testing.T) {
	var buf bytes.Buffer
	var bufLogger = NewLogger(&buf)

	logDebug := func(logger Logger, msg string) { logger.Debug(msg) }
	logInfo := func(logger Logger, msg string) { logger.Info(msg) }
	logWarn := func(logger Logger, msg string) { logger.Warn(msg) }
	logError := func(logger Logger, msg string) { logger.Error(msg) }

	levels := []struct {
		levelFunc   func(logger Logger, msg string)
		levelPrefix string
	}{
		{logDebug, "D"},
		{logInfo, "I"},
		{logWarn, "W"},
		{logError, "E"},
	}

	for _, level := range levels {
		tagLogger1 := bufLogger.WithFields(field("key1", "value1"))
		tagLogger2 := bufLogger.WithFields(field("key2", "value2"), field("key3", "value3"))

		// verifyMsgAndPrefix resets the shared buffer, logs once and checks
		// the message and prefix; callers inspect buf afterwards for fields.
		verifyMsgAndPrefix := func(logger Logger) {
			buf.Reset()
			level.levelFunc(logger, "test-msg")

			out := buf.String()
			assert.Contains(t, out, "test-msg")
			assert.Contains(t, out, "["+level.levelPrefix+"]")
		}

		verifyMsgAndPrefix(bufLogger)

		verifyMsgAndPrefix(tagLogger1)
		assert.Contains(t, buf.String(), "{key1 value1}")
		assert.NotContains(t, buf.String(), "{key2 value2}")
		assert.NotContains(t, buf.String(), "{key3 value3}")

		verifyMsgAndPrefix(tagLogger2)
		assert.Contains(t, buf.String(), "{key2 value2}")
		assert.Contains(t, buf.String(), "{key3 value3}")
		assert.NotContains(t, buf.String(), "{key1 value1}")
	}
}

// TestLevelLogger wraps a writer-backed logger at every level from
// LogLevelFatal down to LogLevelAll, checks Enabled for each level, then logs
// six non-fatal messages and counts how many lines reach the buffer.
func TestLevelLogger(t *testing.T) {
	var buf bytes.Buffer
	var bufLogger = NewLogger(&buf)

	// Number of lines expected out of the six log calls made below.
	expectedLines := map[LogLevel]int{
		LogLevelAll:   6,
		LogLevelDebug: 6,
		LogLevelInfo:  4,
		LogLevelWarn:  2,
		LogLevelError: 1,
		LogLevelFatal: 0,
	}
	for level := LogLevelFatal; level >= LogLevelAll; level-- {
		buf.Reset()
		levelLogger := NewLevelLogger(bufLogger, level)

		for l := LogLevel(0); l <= LogLevelFatal; l++ {
			assert.Equal(t, level <= l, levelLogger.Enabled(l), "levelLogger.Enabled(%v) at %v", l, level)
		}

		// Fatal is deliberately not called here.
		levelLogger.Debug("debug")
		levelLogger.Debugf("debu%v", "g")
		levelLogger.Info("info")
		levelLogger.Infof("inf%v", "o")
		levelLogger.Warn("warn")
		levelLogger.Error("error")

		assert.Equal(t, expectedLines[level], bytes.Count(buf.Bytes(), []byte{'\n'}))
	}
}

================================================
FILE: messages.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"time"

	"github.com/uber/tchannel-go/typed"
)

// messageType defines a type of message
type messageType byte

const (
	messageTypeInitReq         messageType = 0x01
	messageTypeInitRes         messageType = 0x02
	messageTypeCallReq         messageType = 0x03
	messageTypeCallRes         messageType = 0x04
	messageTypeCallReqContinue messageType = 0x13
	messageTypeCallResContinue messageType = 0x14
	messageTypeCancel          messageType = 0xC0
	messageTypePingReq         messageType = 0xd0
	messageTypePingRes         messageType = 0xd1
	messageTypeError           messageType = 0xFF
)

//go:generate stringer -type=messageType

// message is the base interface for messages. Has an id and type, and knows
// how to read and write onto a binary stream
type message interface {
	// ID returns the id of the message
	ID() uint32

	// messageType returns the type of the message
	messageType() messageType

	// read reads the message from a binary stream
	read(r *typed.ReadBuffer) error

	// write writes the message to a binary stream
	write(w *typed.WriteBuffer) error
}

// noBodyMsg supplies no-op read and write methods; it is embedded by message
// types that have nothing to serialize.
type noBodyMsg struct{}

func (noBodyMsg) read(r *typed.ReadBuffer) error   { return nil }
func (noBodyMsg) write(w *typed.WriteBuffer) error { return nil }

// initParams are parameters to an initReq/InitRes
type initParams map[string]string

const (
	// InitParamHostPort contains the host and port of the peer process
	InitParamHostPort = "host_port"
	// InitParamProcessName contains the name of the peer process
	InitParamProcessName = "process_name"
	// InitParamTChannelLanguage contains the library language.
	InitParamTChannelLanguage = "tchannel_language"
	// InitParamTChannelLanguageVersion contains the language build/runtime version.
	InitParamTChannelLanguageVersion = "tchannel_language_version"
	// InitParamTChannelVersion contains the library version.
	InitParamTChannelVersion = "tchannel_version"
)

// initMessage is the base for messages in the initialization handshake
type initMessage struct {
	id         uint32
	Version    uint16
	initParams initParams
}

// read decodes, in order: a uint16 version, a uint16 parameter count, and
// that many key/value pairs of 16-bit length-prefixed strings. Any previously
// held parameters are replaced. The buffer's accumulated error is returned.
func (m *initMessage) read(r *typed.ReadBuffer) error {
	m.Version = r.ReadUint16()

	m.initParams = initParams{}
	np := r.ReadUint16()
	for i := 0; i < int(np); i++ {
		k := r.ReadLen16String()
		v := r.ReadLen16String()
		m.initParams[k] = v
	}

	return r.Err()
}

// write encodes the message in the same layout that read decodes. Parameters
// are written in map iteration order.
func (m *initMessage) write(w *typed.WriteBuffer) error {
	w.WriteUint16(m.Version)
	w.WriteUint16(uint16(len(m.initParams)))

	for k, v := range m.initParams {
		w.WriteLen16String(k)
		w.WriteLen16String(v)
	}

	return w.Err()
}

// ID returns the id of the message.
func (m *initMessage) ID() uint32 {
	return m.id
}

// An initReq contains context information sent from an initiating peer
type initReq struct {
	initMessage
}

func (m *initReq) messageType() messageType { return messageTypeInitReq }

// An initRes contains context information returned to an initiating peer
type initRes struct {
	initMessage
}

func (m *initRes) messageType() messageType { return messageTypeInitRes }

// TransportHeaderName is a type for transport header names.
type TransportHeaderName string

// String returns the header name as a plain string.
func (cn TransportHeaderName) String() string { return string(cn) }

// Known transport header keys for call requests.
// Note: transport header names must be <= 16 bytes:
// https://tchannel.readthedocs.io/en/latest/protocol/#transport-headers
const (
	// ArgScheme header specifies the format of the args.
	ArgScheme TransportHeaderName = "as"

	// CallerName header specifies the name of the service making the call.
	CallerName TransportHeaderName = "cn"

	// ClaimAtFinish header value is host:port specifying the instance to send a claim message
	// to when response is being sent.
	ClaimAtFinish TransportHeaderName = "caf"

	// ClaimAtStart header value is host:port specifying another instance to send a claim message
	// to when work is started.
	ClaimAtStart TransportHeaderName = "cas"

	// FailureDomain header describes a group of related requests to the same service that are
	// likely to fail in the same way if they were to fail.
	FailureDomain TransportHeaderName = "fd"

	// ShardKey header value is used by ringpop to deliver calls to a specific tchannel instance.
	ShardKey TransportHeaderName = "sk"

	// RetryFlags header specifies the retry policies.
	RetryFlags TransportHeaderName = "re"

	// SpeculativeExecution header specifies the number of nodes on which to run the request.
	SpeculativeExecution TransportHeaderName = "se"

	// RoutingDelegate header identifies an intermediate service which knows
	// how to route the request to the intended recipient.
	RoutingDelegate TransportHeaderName = "rd"

	// RoutingKey header identifies a traffic group containing instances of the
	// requested service. A relay may use the routing key over the service if
	// it knows about traffic groups.
	RoutingKey TransportHeaderName = "rk"
)

// transportHeaders are passed as part of a CallReq/CallRes
type transportHeaders map[TransportHeaderName]string

// read decodes a single-byte header count followed by that many key/value
// pairs of 8-bit length-prefixed strings, storing them into ch. The caller
// must supply a non-nil map and check the buffer's error afterwards.
func (ch transportHeaders) read(r *typed.ReadBuffer) {
	nh := r.ReadSingleByte()
	for i := 0; i < int(nh); i++ {
		k := r.ReadLen8String()
		v := r.ReadLen8String()
		ch[TransportHeaderName(k)] = v
	}
}

// write encodes the headers in the layout that read decodes. The count is
// written as a single byte, and headers are written in map iteration order.
func (ch transportHeaders) write(w *typed.WriteBuffer) {
	w.WriteSingleByte(byte(len(ch)))

	for k, v := range ch {
		w.WriteLen8String(k.String())
		w.WriteLen8String(v)
	}
}

// A callReq for service
type callReq struct {
	id         uint32
	TimeToLive time.Duration
	Tracing    Span
	Headers    transportHeaders
	Service    string
}

func (m *callReq) ID() uint32               { return m.id }
func (m *callReq) messageType() messageType { return messageTypeCallReq }

// read decodes, in order: a uint32 time-to-live in milliseconds, the tracing
// span, the 8-bit length-prefixed service name, and the transport headers.
func (m *callReq) read(r *typed.ReadBuffer) error {
	m.TimeToLive = time.Duration(r.ReadUint32()) * time.Millisecond
	m.Tracing.read(r)
	m.Service = r.ReadLen8String()
	m.Headers = transportHeaders{}
	m.Headers.read(r)
	return r.Err()
}

// write encodes the request in the layout that read decodes; the
// time-to-live is converted to whole milliseconds.
func (m *callReq) write(w *typed.WriteBuffer) error {
	w.WriteUint32(uint32(m.TimeToLive / time.Millisecond))
	m.Tracing.write(w)
	w.WriteLen8String(m.Service)
	m.Headers.write(w)
	return w.Err()
}

// A callReqContinue is continuation of a previous callReq
type callReqContinue struct {
	noBodyMsg
	id uint32
}

func (c *callReqContinue) ID() uint32               { return c.id }
func (c *callReqContinue) messageType() messageType { return messageTypeCallReqContinue }

// ResponseCode to a CallReq
type ResponseCode byte

const (
	responseOK               ResponseCode = 0x00
	responseApplicationError ResponseCode = 0x01
)

// callRes is a response to a CallReq
type callRes struct {
	id           uint32
	ResponseCode ResponseCode
	Tracing      Span
	Headers      transportHeaders
}

func (m *callRes) ID() uint32               { return m.id }
func (m *callRes) messageType() messageType { return messageTypeCallRes }

// read decodes, in order: the single-byte response code, the tracing span,
// and the transport headers.
func (m *callRes) read(r *typed.ReadBuffer) error {
	m.ResponseCode = ResponseCode(r.ReadSingleByte())
	m.Tracing.read(r)
	m.Headers = transportHeaders{}
	m.Headers.read(r)
	return r.Err()
}

// write encodes the response in the layout that read decodes.
func (m *callRes) write(w *typed.WriteBuffer) error {
	w.WriteSingleByte(byte(m.ResponseCode))
	m.Tracing.write(w)
	m.Headers.write(w)
	return w.Err()
}

// callResContinue is a continuation of a previous CallRes
type callResContinue struct {
	id uint32
}

func (c *callResContinue) ID() uint32                       { return c.id }
func (c *callResContinue) messageType() messageType         { return messageTypeCallResContinue }
func (c *callResContinue) read(r *typed.ReadBuffer) error   { return nil }
func (c *callResContinue) write(w *typed.WriteBuffer) error { return nil }

// An errorMessage is a system-level error response to a request or a protocol level error
type errorMessage struct {
	id      uint32
	errCode SystemErrCode
	tracing Span
	message string
}

func (m *errorMessage) ID() uint32               { return m.id }
func (m *errorMessage) messageType() messageType { return messageTypeError }

// read decodes, in order: the single-byte error code, the tracing span, and
// the 16-bit length-prefixed message text.
func (m *errorMessage) read(r *typed.ReadBuffer) error {
	m.errCode = SystemErrCode(r.ReadSingleByte())
	m.tracing.read(r)
	m.message = r.ReadLen16String()
	return r.Err()
}

// write encodes the error in the layout that read decodes.
func (m *errorMessage) write(w *typed.WriteBuffer) error {
	w.WriteSingleByte(byte(m.errCode))
	m.tracing.write(w)
	w.WriteLen16String(m.message)
	return w.Err()
}

// AsSystemError builds a system error from the message's code and text.
func (m errorMessage) AsSystemError() error {
	// TODO(mmihic): Might be nice to return one of the well defined error types
	return NewSystemError(m.errCode, m.message)
}

// Error returns the error string of the system error built by AsSystemError.
func (m errorMessage) Error() string {
	return m.AsSystemError().Error()
}

// cancelMessage is the message sent with type messageTypeCancel.
type cancelMessage struct {
	id      uint32
	ttl     uint32 // unused by tchannel-go, but part of the protocol.
	tracing Span
	message string
}

func (m *cancelMessage) ID() uint32               { return m.id }
func (m *cancelMessage) messageType() messageType { return messageTypeCancel }

// read decodes, in order: the uint32 ttl, the tracing span, and the 16-bit
// length-prefixed message text.
func (m *cancelMessage) read(r *typed.ReadBuffer) error {
	m.ttl = r.ReadUint32()
	m.tracing.read(r)
	m.message = r.ReadLen16String()
	return r.Err()
}

// write encodes the cancel in the layout that read decodes.
func (m *cancelMessage) write(w *typed.WriteBuffer) error {
	w.WriteUint32(m.ttl)
	m.tracing.write(w)
	w.WriteLen16String(m.message)
	return w.Err()
}

// AsSystemError builds a system error with code ErrCodeCancelled and the
// message's text.
func (m *cancelMessage) AsSystemError() error {
	return NewSystemError(ErrCodeCancelled, m.message)
}

// pingReq is a protocol level ping request; it has no body.
type pingReq struct {
	noBodyMsg
	id uint32
}

func (c *pingReq) ID() uint32               { return c.id }
func (c *pingReq) messageType() messageType { return messageTypePingReq }

// pingRes is a ping response to a protocol level ping request.
type pingRes struct {
	noBodyMsg
	id uint32
}

func (c *pingRes) ID() uint32               { return c.id }
func (c *pingRes) messageType() messageType { return messageTypePingRes }

// callReqSpan decodes a Span from the _spanLength bytes of the frame payload
// starting at _spanIndex. Any read error is discarded.
// NOTE(review): presumably the caller guarantees f holds a call request —
// confirm at call sites.
func callReqSpan(f *Frame) Span {
	rdr := typed.NewReadBuffer(f.Payload[_spanIndex : _spanIndex+_spanLength])
	var s Span
	s.read(rdr)
	return s
}

================================================
FILE: messages_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"bytes"
	"fmt"
	"testing"
	"time"

	"github.com/uber/tchannel-go/typed"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestInitReq checks the ID and message type reported by an initReq and that
// the message survives a write/read round trip.
func TestInitReq(t *testing.T) {
	req := initReq{
		initMessage{
			id:      0xDEADBEEF,
			Version: 0x02,
			initParams: initParams{
				"lang": "en_US",
				"tz":   "GMT",
			},
		},
	}

	assert.Equal(t, uint32(0xDEADBEEF), req.ID(), "ids do not match")
	assert.Equal(t, messageTypeInitReq, req.messageType(), "types do not match")
	assertRoundTrip(t, &req, &initReq{initMessage{id: 0xDEADBEEF}})
}

// TestInitRes checks the ID and message type reported by an initRes and that
// the message survives a write/read round trip.
func TestInitRes(t *testing.T) {
	res := initRes{
		initMessage{
			id:      0xDEADBEEF,
			Version: 0x04,
			initParams: initParams{
				"lang": "en_US",
				"tz":   "GMT",
			},
		},
	}

	assert.Equal(t, uint32(0xDEADBEEF), res.ID(), "ids do not match")
	assert.Equal(t, messageTypeInitRes, res.messageType(), "types do not match")
	assertRoundTrip(t, &res, &initRes{initMessage{id: 0xDEADBEEF}})
}

// TestCallReq checks the ID and message type reported by a callReq and that
// its TTL, tracing span, headers and service survive a write/read round trip.
func TestCallReq(t *testing.T) {
	r := callReq{
		id:         0xDEADBEEF,
		TimeToLive: time.Second * 45,
		Tracing: Span{
			traceID:  294390430934,
			parentID: 398348934,
			spanID:   12762782,
			flags:    0x01,
		},
		Headers: transportHeaders{
			"r": "c",
			"f": "d",
		},
		Service: "udr",
	}

	assert.Equal(t, uint32(0xDEADBEEF), r.ID())
	assert.Equal(t, messageTypeCallReq, r.messageType())
	assertRoundTrip(t, &r, &callReq{id: 0xDEADBEEF})
}

// TestCallReqContinue checks the ID and message type reported by a
// callReqContinue and that it survives a write/read round trip.
func TestCallReqContinue(t *testing.T) {
	r := callReqContinue{
		id: 0xDEADBEEF,
	}

	assert.Equal(t, uint32(0xDEADBEEF), r.ID())
	assert.Equal(t, messageTypeCallReqContinue, r.messageType())
	assertRoundTrip(t, &r, &callReqContinue{id: 0xDEADBEEF})
}

// TestCallRes checks the ID and message type reported by a callRes and that
// its response code, headers and tracing span survive a write/read round trip.
func TestCallRes(t *testing.T) {
	r := callRes{
		id:           0xDEADBEEF,
		ResponseCode: responseApplicationError,
		Headers: transportHeaders{
			"r": "c",
			"f": "d",
		},
		Tracing: Span{
			traceID:  294390430934,
			parentID: 398348934,
			spanID:   12762782,
			flags:    0x04,
		},
	}

	assert.Equal(t, uint32(0xDEADBEEF), r.ID())
	assert.Equal(t, messageTypeCallRes, r.messageType())
	assertRoundTrip(t, &r, &callRes{id: 0xDEADBEEF})
}

// TestCallResContinue checks the ID and message type reported by a
// callResContinue and that it survives a write/read round trip.
func TestCallResContinue(t *testing.T) {
	r := callResContinue{
		id: 0xDEADBEEF,
	}

	assert.Equal(t, uint32(0xDEADBEEF), r.ID())
	assert.Equal(t, messageTypeCallResContinue, r.messageType())
	assertRoundTrip(t, &r, &callResContinue{id: 0xDEADBEEF})
}

// TestErrorMessage checks the message type reported by an errorMessage and
// that its error code and message text survive a write/read round trip.
func TestErrorMessage(t *testing.T) {
	m := errorMessage{
		errCode: ErrCodeBusy,
		message: "go away",
	}

	assert.Equal(t, messageTypeError, m.messageType())
	assertRoundTrip(t, &m, &errorMessage{})
}

// assertRoundTrip writes expected into a 1024-byte write buffer, reads the
// resulting bytes back into actual, and asserts that the two messages are
// equal. The message ID is not part of the written body in the messages
// exercised here, so callers pre-populate actual with the expected ID.
func assertRoundTrip(t *testing.T, expected message, actual message) {
	w := typed.NewWriteBufferWithSize(1024)
	require.Nil(t, expected.write(w), fmt.Sprintf("error writing message %v", expected.messageType()))

	var b bytes.Buffer
	w.FlushTo(&b)

	r := typed.NewReadBuffer(b.Bytes())
	require.Nil(t, actual.read(r), fmt.Sprintf("error reading message %v", expected.messageType()))

	assert.Equal(t, expected, actual, fmt.Sprintf("pre- and post-marshal %v do not match", expected.messageType()))
}

================================================
FILE: messagetype_string.go
================================================
// Code generated by "stringer -type=messageType"; DO NOT EDIT.

package tchannel

import "strconv"

func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[messageTypeInitReq-1]
	_ = x[messageTypeInitRes-2]
	_ = x[messageTypeCallReq-3]
	_ = x[messageTypeCallRes-4]
	_ = x[messageTypeCallReqContinue-19]
	_ = x[messageTypeCallResContinue-20]
	_ = x[messageTypeCancel-192]
	_ = x[messageTypePingReq-208]
	_ = x[messageTypePingRes-209]
	_ = x[messageTypeError-255]
}

const (
	_messageType_name_0 = "messageTypeInitReqmessageTypeInitResmessageTypeCallReqmessageTypeCallRes"
	_messageType_name_1 = "messageTypeCallReqContinuemessageTypeCallResContinue"
	_messageType_name_2 = "messageTypeCancel"
	_messageType_name_3 = "messageTypePingReqmessageTypePingRes"
	_messageType_name_4 = "messageTypeError"
)

var (
	_messageType_index_0 = [...]uint8{0, 18, 36, 54, 72}
	_messageType_index_1 = [...]uint8{0, 26, 52}
	_messageType_index_3 = [...]uint8{0, 18, 36}
)

func (i messageType) String() string {
	switch {
	case 1 <= i && i <= 4:
		i -= 1
		return _messageType_name_0[_messageType_index_0[i]:_messageType_index_0[i+1]]
	case 19 <= i && i <= 20:
		i -= 19
		return _messageType_name_1[_messageType_index_1[i]:_messageType_index_1[i+1]]
	case i == 192:
		return _messageType_name_2
	case 208 <= i && i <= 209:
		i -= 208
		return _messageType_name_3[_messageType_index_3[i]:_messageType_index_3[i+1]]
	case i == 255:
		return _messageType_name_4
	default:
		return "messageType(" + strconv.FormatInt(int64(i), 10) + ")"
	}
}

================================================
FILE: mex.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"errors"
	"fmt"
	"sync"

	"github.com/uber/tchannel-go/typed"

	"go.uber.org/atomic"
	"golang.org/x/net/context"
)

// Errors produced by message exchanges and message exchange sets.
var (
	// errDuplicateMex is returned when adding an exchange whose message ID
	// is already registered in the set.
	errDuplicateMex = errors.New("multiple attempts to use the message id")
	// errMexShutdown is the error broadcast to waiters when an exchange is
	// shut down.
	errMexShutdown = errors.New("mex has been shutdown")
	// errMexSetShutdown is returned when adding an exchange to a set that
	// has already been shut down.
	errMexSetShutdown      = errors.New("mexset has been shutdown")
	errMexChannelFull      = NewSystemError(ErrCodeBusy, "cannot send frame to message exchange channel")
	// errUnexpectedFrameType is returned when a received frame has an
	// unexpected message ID or message type.
	errUnexpectedFrameType = errors.New("unexpected frame received")
)

const (
	messageExchangeSetInbound  = "inbound"
	messageExchangeSetOutbound = "outbound"

	// mexChannelBufferSize is the size of the message exchange channel buffer.
	mexChannelBufferSize = 2
)

// errNotifier stores a single error and broadcasts it to waiters by closing c.
type errNotifier struct {
	// c is closed once an error has been stored in err.
	c chan struct{}
	// err is written before c is closed and read only after c is closed.
	err error
	// notified ensures that only the first Notify call stores an error.
	notified atomic.Bool
}

// newErrNotifier returns an errNotifier with an open notification channel.
func newErrNotifier() errNotifier {
	return errNotifier{c: make(chan struct{})}
}

// Notify will store the error and notify all waiters on c that there's an error.
func (e *errNotifier) Notify(err error) error {
	// The code should never try to Notify(nil).
	if err == nil {
		panic("cannot Notify with no error")
	}

	// There may be some sort of race where we try to notify the mex twice.
	// Only the first caller wins the CAS; later callers get an error back
	// and leave the stored error untouched.
	if !e.notified.CAS(false, true) {
		return fmt.Errorf("cannot broadcast error: %v, already have: %v", err, e.err)
	}

	// Store the error before closing the channel so that anyone who
	// observes the closed channel also observes the error.
	e.err = err
	close(e.c)
	return nil
}

// checkErr returns previously notified errors (if any).
// It never blocks: if no error has been notified yet, it returns nil.
func (e *errNotifier) checkErr() error {
	select {
	case <-e.c:
		return e.err
	default:
		return nil
	}
}

// A messageExchange tracks this Connection's side of a message exchange with a
// peer. Each message exchange has a channel that can be used to receive
// frames from the peer, and a Context that controls when the exchange has
// timed out or been cancelled.
type messageExchange struct {
	// recvCh carries frames forwarded from the peer by forwardPeerFrame.
	recvCh chan *Frame
	// errCh is notified on shutdown and when the exchange set is stopped.
	errCh errNotifier
	// ctx bounds the lifetime of the exchange (timeout/cancellation).
	ctx context.Context
	// ctxCancel, if non-nil, is invoked when a cancel frame is handled.
	ctxCancel context.CancelFunc
	// msgID is the message ID that every frame on this exchange must carry.
	msgID   uint32
	msgType messageType
	// mexset is the set this exchange is registered in.
	mexset *messageExchangeSet
	// framePool is used to release error frames once they are deserialized.
	framePool FramePool

	// shutdownAtomic ensures shutdown only runs once.
	shutdownAtomic atomic.Bool
	// errChNotified ensures errCh.Notify is only called once.
	errChNotified atomic.Bool
}

// checkError is called before waiting on the mex channels.
// It returns any existing errors (timeout, cancellation, connection errors).
// Context errors take precedence over errors notified on errCh.
func (mex *messageExchange) checkError() error {
	if err := mex.ctx.Err(); err != nil {
		return GetContextError(err)
	}

	return mex.errCh.checkErr()
}

// forwardPeerFrame forwards a frame from a peer to the message exchange, where
// it can be pulled by whatever application thread is handling the exchange.
// It blocks until the frame is accepted by recvCh, the context is done, or an
// error is notified on errCh, and returns a non-nil error in the latter two
// cases (unless the frame could still be queued).
func (mex *messageExchange) forwardPeerFrame(frame *Frame) error {
	// We want a very specific priority here:
	// 1. Timeouts/cancellation (mex.ctx errors)
	// 2. Whether recvCh has buffer space (non-blocking select over mex.recvCh)
	// 3. Other mex errors (mex.errCh)
	// Which is why we check the context error only (instead of mex.checkError).
	// In the mex.errCh case, we do a non-blocking write to recvCh to prioritize it.
	if err := mex.ctx.Err(); err != nil {
		return GetContextError(err)
	}

	select {
	case mex.recvCh <- frame:
		return nil
	case <-mex.ctx.Done():
		// Note: One slow reader processing a large request could stall the connection.
		// If we see this, we need to increase the recvCh buffer size.
		return GetContextError(mex.ctx.Err())
	case <-mex.errCh.c:
		// Select will randomly choose a case, but we want to prioritize
		// sending a frame over the errCh. Try a non-blocking write.
		select {
		case mex.recvCh <- frame:
			return nil
		default:
		}
		return mex.errCh.err
	}
}

// handleCancel cancels the exchange's context, if a cancel function was
// supplied when the exchange was created. The cancel frame itself is unused.
func (mex *messageExchange) handleCancel(_ *Frame) {
	if mex.ctxCancel != nil {
		mex.ctxCancel()
	}
}

// checkFrame verifies that the frame's message ID matches this exchange.
// On a mismatch it logs an error and returns errUnexpectedFrameType.
func (mex *messageExchange) checkFrame(frame *Frame) error {
	if frame.Header.ID != mex.msgID {
		mex.mexset.log.WithFields(
			LogField{"msgId", mex.msgID},
			LogField{"header", frame.Header},
		).Error("recvPeerFrame received msg with unexpected ID.")
		return errUnexpectedFrameType
	}
	return nil
}

// recvPeerFrame waits for a new frame from the peer, or until the context
// expires or is cancelled. When the context has an error, onCtxErr is invoked
// before returning so a cancel can be propagated.
func (mex *messageExchange) recvPeerFrame() (*Frame, error) {
	// We have to check frames/errors in a very specific order here:
	// 1. Timeouts/cancellation (mex.ctx errors)
	// 2. Any pending frames (non-blocking select over mex.recvCh)
	// 3. Other mex errors (mex.errCh)
	// Which is why we check the context error only (instead of mex.checkError).
	// In the mex.errCh case, we do a non-blocking read from recvCh to prioritize it.
	if err := mex.ctx.Err(); err != nil {
		mex.onCtxErr(err)
		return nil, GetContextError(err)
	}

	select {
	case frame := <-mex.recvCh:
		if err := mex.checkFrame(frame); err != nil {
			return nil, err
		}
		return frame, nil
	case <-mex.ctx.Done():
		mex.onCtxErr(mex.ctx.Err())
		return nil, GetContextError(mex.ctx.Err())
	case <-mex.errCh.c:
		// Select will randomly choose a case, but we want to prioritize
		// receiving a frame over errCh. Try a non-blocking read.
		select {
		case frame := <-mex.recvCh:
			if err := mex.checkFrame(frame); err != nil {
				return nil, err
			}
			return frame, nil
		default:
		}
		return nil, mex.errCh.err
	}
}

// recvPeerFrameOfType waits for a new frame of a given type from the peer, failing
// if the next frame received is not of that type.
// If an error frame is returned, then the errorMessage is returned as the error.
// Any other frame type is logged and reported as errUnexpectedFrameType.
func (mex *messageExchange) recvPeerFrameOfType(msgType messageType) (*Frame, error) {
	frame, err := mex.recvPeerFrame()
	if err != nil {
		return nil, err
	}

	switch frame.Header.messageType {
	case msgType:
		return frame, nil

	case messageTypeError:
		// If we read an error frame, we can release it once we deserialize it.
		defer mex.framePool.Release(frame)

		errMsg := errorMessage{
			id: frame.Header.ID,
		}
		var rbuf typed.ReadBuffer
		rbuf.Wrap(frame.SizedPayload())
		if err := errMsg.read(&rbuf); err != nil {
			return nil, err
		}
		return nil, errMsg

	default:
		// TODO(mmihic): Should be treated as a protocol error
		mex.mexset.log.WithFields(
			LogField{"header", frame.Header},
			LogField{"expectedType", msgType},
			LogField{"expectedID", mex.msgID},
		).Warn("Received unexpected frame.")
		return nil, errUnexpectedFrameType
	}
}

// onCtxErr is called when the exchange's context has an error. Only
// context.Canceled triggers the exchange set's onCancel callback (if set);
// every other error is ignored here.
func (mex *messageExchange) onCtxErr(err error) {
	// On canceled contexts, we may need to send a cancel message.
	if err != context.Canceled {
		return
	}

	if onCancel := mex.mexset.onCancel; onCancel != nil {
		onCancel(mex.msgID)
	}
}

// shutdown shuts down the message exchange, removing it from the message
// exchange set so that it cannot receive more messages from the peer. The
// receive channel remains open, however, in case there are concurrent
// goroutines sending to it.
func (mex *messageExchange) shutdown() {
	// The reader and writer side can both hit errors and try to shutdown the mex,
	// so we ensure that it's only shut down once.
	if !mex.shutdownAtomic.CAS(false, true) {
		return
	}

	// errCh may already have been notified by stopExchanges; only notify
	// if nobody else has.
	if mex.errChNotified.CAS(false, true) {
		mex.errCh.Notify(errMexShutdown)
	}

	mex.mexset.removeExchange(mex.msgID)
	return
}

// inboundExpired is called when an exchange is canceled or it times out,
// but a handler may still be running in the background. Since the handler may
// still write to the exchange, we cannot shutdown the exchange, but we should
// remove it from the connection's exchange list.
func (mex *messageExchange) inboundExpired() {
	mex.mexset.expireExchange(mex.msgID)
}

// A messageExchangeSet manages a set of active message exchanges. It is
// mainly used to route frames from a peer to the appropriate messageExchange,
// or to cancel or mark a messageExchange as being in error. Each Connection
// maintains two messageExchangeSets, one to manage exchanges that it has
// initiated (outbound), and another to manage exchanges that the peer has
// initiated (inbound). The message-type specific handlers are responsible for
// ensuring that their message exchanges are properly registered and removed
// from the corresponding exchange set.
type messageExchangeSet struct {
	sync.RWMutex

	log  Logger
	name string
	// onCancel, if non-nil, is called with the message ID when an
	// exchange observes a canceled context.
	onCancel func(id uint32)
	// onRemoved is called after an exchange is removed or expired.
	onRemoved func()
	// onAdded is called after an exchange is successfully added.
	onAdded func()

	// maps are mutable, and are protected by the mutex.
	exchanges        map[uint32]*messageExchange
	expiredExchanges map[uint32]struct{}
	// shutdown is set by stopExchanges; no exchanges can be added afterwards.
	shutdown bool
}

// newMessageExchangeSet creates a new messageExchangeSet with a given name.
// The set's logger is the given logger tagged with the set name. The
// onCancel/onRemoved/onAdded callbacks are left unset here.
func newMessageExchangeSet(log Logger, name string) *messageExchangeSet {
	return &messageExchangeSet{
		name:             name,
		log:              log.WithFields(LogField{"exchange", name}),
		exchanges:        make(map[uint32]*messageExchange),
		expiredExchanges: make(map[uint32]struct{}),
	}
}

// addExchange adds an exchange, it must be called with the mexset locked.
func (mexset *messageExchangeSet) addExchange(mex *messageExchange) error { if mexset.shutdown { return errMexSetShutdown } if _, ok := mexset.exchanges[mex.msgID]; ok { return errDuplicateMex } mexset.exchanges[mex.msgID] = mex return nil } // newExchange creates and adds a new message exchange to this set func (mexset *messageExchangeSet) newExchange(ctx context.Context, ctxCancel context.CancelFunc, framePool FramePool, msgType messageType, msgID uint32, bufferSize int) (*messageExchange, error) { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("Creating new %s message exchange for [%v:%d]", mexset.name, msgType, msgID) } mex := &messageExchange{ msgType: msgType, msgID: msgID, ctx: ctx, ctxCancel: ctxCancel, recvCh: make(chan *Frame, bufferSize), errCh: newErrNotifier(), mexset: mexset, framePool: framePool, } mexset.Lock() addErr := mexset.addExchange(mex) mexset.Unlock() if addErr != nil { logger := mexset.log.WithFields( LogField{"msgID", mex.msgID}, LogField{"msgType", mex.msgType}, LogField{"exchange", mexset.name}, ) if addErr == errMexSetShutdown { logger.Warn("Attempted to create new mex after mexset shutdown.") } else if addErr == errDuplicateMex { logger.Warn("Duplicate msg ID for active and new mex.") } return nil, addErr } mexset.onAdded() // TODO(mmihic): Put into a deadline ordered heap so we can garbage collected expired exchanges return mex, nil } // deleteExchange will delete msgID, and return whether it was found or whether it was // timed out. This method must be called with the lock. func (mexset *messageExchangeSet) deleteExchange(msgID uint32) (found, timedOut bool) { if _, found := mexset.exchanges[msgID]; found { delete(mexset.exchanges, msgID) return true, false } if _, expired := mexset.expiredExchanges[msgID]; expired { delete(mexset.expiredExchanges, msgID) return false, true } return false, false } // removeExchange removes a message exchange from the set, if it exists. 
func (mexset *messageExchangeSet) removeExchange(msgID uint32) { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("Removing %s message exchange %d", mexset.name, msgID) } mexset.Lock() found, expired := mexset.deleteExchange(msgID) mexset.Unlock() if !found && !expired { mexset.log.WithFields( LogField{"msgID", msgID}, ).Error("Tried to remove exchange multiple times") return } // If the message exchange was found, then we perform clean up actions. // These clean up actions can only be run once per exchange. mexset.onRemoved() } // expireExchange is similar to removeExchange, but it marks the exchange as // expired. func (mexset *messageExchangeSet) expireExchange(msgID uint32) { mexset.log.Debugf( "Removing %s message exchange %d due to timeout, cancellation or blackhole", mexset.name, msgID, ) mexset.Lock() // TODO(aniketp): explore if cancel can be called everytime we expire an exchange found, expired := mexset.deleteExchange(msgID) if found || expired { // Record in expiredExchanges if we deleted the exchange. 
mexset.expiredExchanges[msgID] = struct{}{} } mexset.Unlock() if expired { mexset.log.WithFields(LogField{"msgID", msgID}).Info("Exchange expired already") } mexset.onRemoved() } func (mexset *messageExchangeSet) count() int { mexset.RLock() count := len(mexset.exchanges) mexset.RUnlock() return count } // forwardPeerFrame forwards a frame from the peer to the appropriate message // exchange func (mexset *messageExchangeSet) forwardPeerFrame(frame *Frame) error { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("forwarding %s %s", mexset.name, frame.Header) } mexset.RLock() mex := mexset.exchanges[frame.Header.ID] mexset.RUnlock() if mex == nil { // This is ok since the exchange might have expired or been cancelled mexset.log.WithFields( LogField{"frameHeader", frame.Header.String()}, LogField{"exchange", mexset.name}, ).Info("Received frame for unknown message exchange.") return nil } if err := mex.forwardPeerFrame(frame); err != nil { mexset.log.WithFields( LogField{"frameHeader", frame.Header.String()}, LogField{"frameSize", frame.Header.FrameSize()}, LogField{"exchange", mexset.name}, ErrField(err), ).Info("Failed to forward frame.") return err } return nil } func (mexset *messageExchangeSet) handleCancel(frame *Frame) { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("handling cancel for %s", mexset.name, frame.Header) } mexset.RLock() mex := mexset.exchanges[frame.Header.ID] mexset.RUnlock() if mex == nil { // This is ok since the exchange might have expired. mexset.log.WithFields( LogField{"frameHeader", frame.Header.String()}, LogField{"exchange", mexset.name}, ).Info("Received cancel frame for unknown message exchange.") return } mex.handleCancel(frame) } // copyExchanges returns a copy of the exchanges if the exchange is active. // The caller must lock the mexset. 
func (mexset *messageExchangeSet) copyExchanges() (shutdown bool, exchanges map[uint32]*messageExchange) { if mexset.shutdown { return true, nil } exchangesCopy := make(map[uint32]*messageExchange, len(mexset.exchanges)) for k, mex := range mexset.exchanges { exchangesCopy[k] = mex } return false, exchangesCopy } // stopExchanges stops all message exchanges to unblock all waiters on the mex. // This should only be called on connection failures. func (mexset *messageExchangeSet) stopExchanges(err error) { if mexset.log.Enabled(LogLevelDebug) { mexset.log.Debugf("stopping %v exchanges due to error: %v", mexset.count(), err) } mexset.Lock() shutdown, exchanges := mexset.copyExchanges() mexset.shutdown = true mexset.Unlock() if shutdown { mexset.log.Debugf("mexset has already been shutdown") return } for _, mex := range exchanges { // When there's a connection failure, we want to notify blocked callers that the // call will fail, but we don't want to shutdown the exchange as only the // arg reader/writer should shutdown the exchange. Otherwise, our guarantee // on sendChRefs that there's no references to sendCh is violated since // readers/writers could still have a reference to sendCh even though // we shutdown the exchange and called Done on sendChRefs. if mex.errChNotified.CAS(false, true) { mex.errCh.Notify(err) } } } ================================================ FILE: mex_utils_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "fmt" "strings" ) // CheckEmptyExchangesConn checks whether all exchanges for the given connection are empty. // If there are exchanges, a string with information about leftover exchanges is returned. func CheckEmptyExchangesConn(c *ConnectionRuntimeState) string { var errors []string checkExchange := func(e ExchangeSetRuntimeState) { if e.Count > 0 { errors = append(errors, fmt.Sprintf(" %v leftover %v exchanges", e.Name, e.Count)) for _, v := range e.Exchanges { errors = append(errors, fmt.Sprintf(" exchanges: %+v", v)) } } } checkExchange(c.InboundExchange) checkExchange(c.OutboundExchange) if len(errors) == 0 { return "" } return fmt.Sprintf("Connection %d has leftover exchanges:\n\t%v", c.ID, strings.Join(errors, "\n\t")) } // CheckEmptyExchangesConns checks that all exchanges for the given connections are empty. 
func CheckEmptyExchangesConns(connections []*ConnectionRuntimeState) string { var errors []string for _, c := range connections { if v := CheckEmptyExchangesConn(c); v != "" { errors = append(errors, v) } } return strings.Join(errors, "\n") } // CheckEmptyExchanges checks that all exchanges for the given channel are empty. // // TODO: Remove CheckEmptyExchanges and friends in favor of // testutils.TestServer's verification. func CheckEmptyExchanges(ch *Channel) string { state := ch.IntrospectState(&IntrospectionOptions{IncludeExchanges: true}) var connections []*ConnectionRuntimeState for _, peer := range state.RootPeers { for _, conn := range peer.InboundConnections { connections = append(connections, &conn) } for _, conn := range peer.OutboundConnections { connections = append(connections, &conn) } } return CheckEmptyExchangesConns(connections) } ================================================ FILE: outbound.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
// IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"fmt"
	"time"

	"github.com/uber/tchannel-go/typed"

	"github.com/opentracing/opentracing-go"
	"github.com/opentracing/opentracing-go/ext"
	"golang.org/x/net/context"
)

// maxMethodSize is the maximum size of arg1.
const maxMethodSize = 16 * 1024

// beginCall begins an outbound call on the connection.
//
// It requires an active connection and a context deadline that is at least a
// millisecond away. It then registers a new outbound message exchange, builds
// the OutboundCall and its OutboundCallResponse, and writes methodName as
// arg1. Errors returned include ErrConnectionClosed, ErrTimeoutRequired,
// ErrTimeout, the converted context error, and any error from creating the
// exchange or writing the method.
func (c *Connection) beginCall(ctx context.Context, serviceName, methodName string, callOptions *CallOptions) (*OutboundCall, error) {
	now := c.timeNow()

	switch state := c.readState(); state {
	case connectionActive:
		break
	case connectionStartClose, connectionInboundClosed, connectionClosed:
		return nil, ErrConnectionClosed
	default:
		return nil, errConnectionUnknownState{"beginCall", state}
	}

	deadline, ok := ctx.Deadline()
	if !ok {
		// This case is handled by validateCall, so we should
		// never get here.
		return nil, ErrTimeoutRequired
	}

	// If the timeToLive is less than a millisecond, it will be encoded as 0 on
	// the wire, hence we return a timeout immediately.
	timeToLive := deadline.Sub(now)
	if timeToLive < time.Millisecond {
		return nil, ErrTimeout
	}

	if err := ctx.Err(); err != nil {
		return nil, GetContextError(err)
	}

	requestID := c.NextMessageID()
	mex, err := c.outbound.newExchange(ctx, c.outboundCtxCancel, c.opts.FramePool, messageTypeCallReq, requestID, mexChannelBufferSize)
	if err != nil {
		return nil, err
	}

	// Close may have been called between the time we checked the state and us creating the exchange.
	if state := c.readState(); state != connectionActive {
		mex.shutdown()
		return nil, ErrConnectionClosed
	}

	// Note: We don't verify number of transport headers as the library doesn't
	// allow adding arbitrary headers. Ensure we never add >= 256 headers here.
	headers := transportHeaders{
		CallerName: c.localPeerInfo.ServiceName,
	}
	callOptions.setHeaders(headers)
	// Call options carried on the context are applied after the explicit
	// callOptions, so they override them.
	if opts := currentCallOptions(ctx); opts != nil {
		opts.overrideHeaders(headers)
	}

	call := new(OutboundCall)
	call.mex = mex
	call.conn = c
	call.callReq = callReq{
		id:         requestID,
		Headers:    headers,
		Service:    serviceName,
		TimeToLive: timeToLive,
	}
	call.statsReporter = c.statsReporter
	call.createStatsTags(c.commonStatsTags, callOptions, methodName)
	call.log = c.log.WithFields(LogField{"Out-Call", requestID})

	// TODO(mmihic): It'd be nice to do this without an fptr
	// The initial fragment carries the callReq; later fragments are
	// callReqContinue messages.
	call.messageForFragment = func(initial bool) message {
		if initial {
			return &call.callReq
		}

		return new(callReqContinue)
	}

	call.contents = newFragmentingWriter(call.log, call, c.opts.ChecksumType.New())

	response := new(OutboundCallResponse)
	response.startedAt = now
	response.timeNow = c.timeNow
	response.requestState = callOptions.RequestState
	// The response shares the call's message exchange.
	response.mex = mex
	response.log = c.log.WithFields(LogField{"Out-Response", requestID})
	response.span = c.startOutboundSpan(ctx, serviceName, methodName, call, now)
	// The initial fragment is read into callRes; later fragments are
	// callResContinue messages.
	response.messageForFragment = func(initial bool) message {
		if initial {
			return &response.callRes
		}

		return new(callResContinue)
	}
	response.contents = newFragmentingReader(response.log, response)
	response.statsReporter = call.statsReporter
	response.commonStatsTags = call.commonStatsTags

	call.response = response

	if err := call.writeMethod([]byte(methodName)); err != nil {
		return nil, err
	}
	return call, nil
}

// outboundCtxCancel is the cancel function handed to outbound message
// exchanges. It never cancels anything; it only logs at debug level.
func (c *Connection) outboundCtxCancel() {
	// outbound contexts are created by callers, can't cancel them.
	// However, we shouldn't be trying to cancel them, so log.
	c.log.Debug("unexpected cancel of outbound context")
}

// handleCallRes handles an incoming call res message, forwarding the
// frame to the response channel waiting for it.
// It returns true when the frame could not be forwarded and false otherwise.
// NOTE(review): presumably true tells the caller to release the frame, as
// documented for handleError — confirm against the caller.
func (c *Connection) handleCallRes(frame *Frame) bool {
	if err := c.outbound.forwardPeerFrame(frame); err != nil {
		return true
	}
	return false
}

// handleCallResContinue handles an incoming call res continue message,
// forwarding the frame to the response channel waiting for it.
// It returns true when the frame could not be forwarded and false otherwise.
func (c *Connection) handleCallResContinue(frame *Frame) bool {
	if err := c.outbound.forwardPeerFrame(frame); err != nil {
		return true
	}
	return false
}

// An OutboundCall is an active call to a remote peer. A client makes a call
// by calling BeginCall on the Channel, writing argument content via
// ArgWriter2() ArgWriter3(), and then reading response data via the
// ArgReader2() and ArgReader3() methods on the Response() object.
type OutboundCall struct {
	reqResWriter

	// callReq is the message written as the initial fragment of the call.
	callReq callReq
	// response is the response object paired with this call.
	response *OutboundCallResponse
	// statsReporter and commonStatsTags are used to emit call metrics.
	statsReporter   StatsReporter
	commonStatsTags map[string]string
}

// Response provides access to the call's response object, which can be used to
// read response arguments
func (call *OutboundCall) Response() *OutboundCallResponse {
	return call.response
}

// createStatsTags builds the common stats tags for this call: the target
// service, the connection's tags and, for non-HTTP formats, the target endpoint.
func (call *OutboundCall) createStatsTags(connectionTags map[string]string, callOptions *CallOptions, method string) { call.commonStatsTags = map[string]string{ "target-service": call.callReq.Service, } for k, v := range connectionTags { call.commonStatsTags[k] = v } if callOptions.Format != HTTP { call.commonStatsTags["target-endpoint"] = string(method) } } // writeMethod writes the method (arg1) to the call func (call *OutboundCall) writeMethod(method []byte) error { call.statsReporter.IncCounter("outbound.calls.send", call.commonStatsTags, 1) return NewArgWriter(call.arg1Writer()).Write(method) } // Arg2Writer returns a WriteCloser that can be used to write the second argument. // The returned writer must be closed once the write is complete. func (call *OutboundCall) Arg2Writer() (ArgWriter, error) { return call.arg2Writer() } // Arg3Writer returns a WriteCloser that can be used to write the last argument. // The returned writer must be closed once the write is complete. func (call *OutboundCall) Arg3Writer() (ArgWriter, error) { return call.arg3Writer() } // LocalPeer returns the local peer information for this call. func (call *OutboundCall) LocalPeer() LocalPeerInfo { return call.conn.localPeerInfo } // RemotePeer returns the remote peer information for this call. func (call *OutboundCall) RemotePeer() PeerInfo { return call.conn.RemotePeerInfo() } func (call *OutboundCall) doneSending() {} // An OutboundCallResponse is the response to an outbound call type OutboundCallResponse struct { reqResReader callRes callRes requestState *RequestState // startedAt is the time at which the outbound call was started. startedAt time.Time timeNow func() time.Time span opentracing.Span statsReporter StatsReporter commonStatsTags map[string]string } // ApplicationError returns true if the call resulted in an application level error // TODO(mmihic): In current implementation, you must have called Arg2Reader before this // method returns the proper value. 
We should instead have this block until the first // fragment is available, if the first fragment hasn't been received. func (response *OutboundCallResponse) ApplicationError() bool { // TODO(mmihic): Wait for first fragment return response.callRes.ResponseCode == responseApplicationError } // Format the format of the request from the ArgScheme transport header. func (response *OutboundCallResponse) Format() Format { return Format(response.callRes.Headers[ArgScheme]) } // Arg2Reader returns an ArgReader to read the second argument. // The ReadCloser must be closed once the argument has been read. func (response *OutboundCallResponse) Arg2Reader() (ArgReader, error) { var method []byte if err := NewArgReader(response.arg1Reader()).Read(&method); err != nil { return nil, err } return response.arg2Reader() } // Arg3Reader returns an ArgReader to read the last argument. // The ReadCloser must be closed once the argument has been read. func (response *OutboundCallResponse) Arg3Reader() (ArgReader, error) { return response.arg3Reader() } // handleError handles an error coming back from the peer. If the error is a // protocol level error, the entire connection will be closed. If the error is // a request specific error, it will be written to the request's response // channel and converted into a SystemError returned from the next reader or // access call. // The return value is whether the frame should be released immediately. 
func (c *Connection) handleError(frame *Frame) bool { errMsg := errorMessage{ id: frame.Header.ID, } rbuf := typed.NewReadBuffer(frame.SizedPayload()) if err := errMsg.read(rbuf); err != nil { c.log.WithFields( LogField{"remotePeer", c.remotePeerInfo}, ErrField(err), ).Warn("Unable to read error frame.") c.connectionError("parsing error frame", err) return true } if errMsg.errCode == ErrCodeProtocol { c.log.WithFields( LogField{"remotePeer", c.remotePeerInfo}, LogField{"error", errMsg.message}, ).Warn("Peer reported protocol error.") c.connectionError("received protocol error", errMsg.AsSystemError()) return true } if err := c.outbound.forwardPeerFrame(frame); err != nil { c.log.WithFields( LogField{"frameHeader", frame.Header.String()}, LogField{"id", errMsg.id}, LogField{"errorMessage", errMsg.message}, LogField{"errorCode", errMsg.errCode}, ErrField(err), ).Info("Failed to forward error frame.") return true } // If the frame was forwarded, then the other side is responsible for releasing the frame. return false } func cloneTags(tags map[string]string) map[string]string { newTags := make(map[string]string, len(tags)) for k, v := range tags { newTags[k] = v } return newTags } // doneReading shuts down the message exchange for this call. // For outgoing calls, the last message is reading the call response. func (response *OutboundCallResponse) doneReading(unexpected error) { now := response.timeNow() isSuccess := unexpected == nil && !response.ApplicationError() lastAttempt := isSuccess || !response.requestState.HasRetries(unexpected) // TODO how should this work with retries? 
if span := response.span; span != nil { if unexpected != nil { span.LogEventWithPayload("error", unexpected) } if !isSuccess && lastAttempt { ext.Error.Set(span, true) } span.FinishWithOptions(opentracing.FinishOptions{FinishTime: now}) } latency := now.Sub(response.startedAt) response.statsReporter.RecordTimer("outbound.calls.per-attempt.latency", response.commonStatsTags, latency) if lastAttempt { requestLatency := response.requestState.SinceStart(now, latency) response.statsReporter.RecordTimer("outbound.calls.latency", response.commonStatsTags, requestLatency) } if retryCount := response.requestState.RetryCount(); retryCount > 0 { retryTags := cloneTags(response.commonStatsTags) retryTags["retry-count"] = fmt.Sprint(retryCount) response.statsReporter.IncCounter("outbound.calls.retries", retryTags, 1) } if unexpected != nil { // TODO(prashant): Report the error code type as per metrics doc and enable. // response.statsReporter.IncCounter("outbound.calls.system-errors", response.commonStatsTags, 1) } else if response.ApplicationError() { // TODO(prashant): Figure out how to add "type" to tags, which TChannel does not know about. response.statsReporter.IncCounter("outbound.calls.per-attempt.app-errors", response.commonStatsTags, 1) if lastAttempt { response.statsReporter.IncCounter("outbound.calls.app-errors", response.commonStatsTags, 1) } } else { response.statsReporter.IncCounter("outbound.calls.success", response.commonStatsTags, 1) } response.mex.shutdown() } func validateCall(ctx context.Context, serviceName, methodName string, callOpts *CallOptions) error { if serviceName == "" { return ErrNoServiceName } if len(methodName) > maxMethodSize { return ErrMethodTooLarge } if _, ok := ctx.Deadline(); !ok { return ErrTimeoutRequired } return nil } ================================================ FILE: peer.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "container/heap" "errors" "strings" "sync" "time" "github.com/uber/tchannel-go/trand" "go.uber.org/atomic" "golang.org/x/net/context" ) var ( // ErrInvalidConnectionState indicates that the connection is not in a valid state. // This may be due to a race between selecting the connection and it closing, so // it is a network failure that can be retried. ErrInvalidConnectionState = NewSystemError(ErrCodeNetwork, "connection is in an invalid state") // ErrNoPeers indicates that there are no peers. ErrNoPeers = errors.New("no peers available") // ErrPeerNotFound indicates that the specified peer was not found. ErrPeerNotFound = errors.New("peer not found") // ErrNoNewPeers indicates that no previously unselected peer is available. ErrNoNewPeers = errors.New("no new peer available") peerRng = trand.NewSeeded() ) // Connectable is the interface used by peers to create connections. 
type Connectable interface { // Connect tries to connect to the given hostPort. Connect(ctx context.Context, hostPort string) (*Connection, error) // Logger returns the logger to use. Logger() Logger } // PeerList maintains a list of Peers. type PeerList struct { sync.RWMutex parent *RootPeerList peersByHostPort map[string]*peerScore peerHeap *peerHeap scoreCalculator ScoreCalculator lastSelected uint64 } func newPeerList(root *RootPeerList) *PeerList { return &PeerList{ parent: root, peersByHostPort: make(map[string]*peerScore), scoreCalculator: newPreferIncomingCalculator(), peerHeap: newPeerHeap(), } } // SetStrategy sets customized peer selection strategy. func (l *PeerList) SetStrategy(sc ScoreCalculator) { l.Lock() defer l.Unlock() l.scoreCalculator = sc for _, ps := range l.peersByHostPort { newScore := l.scoreCalculator.GetScore(ps.Peer) l.updatePeer(ps, newScore) } } // Siblings don't share peer lists (though they take care not to double-connect // to the same hosts). func (l *PeerList) newSibling() *PeerList { sib := newPeerList(l.parent) return sib } // Add adds a peer to the list if it does not exist, or returns any existing peer. func (l *PeerList) Add(hostPort string) *Peer { if ps, ok := l.exists(hostPort); ok { return ps.Peer } l.Lock() defer l.Unlock() if p, ok := l.peersByHostPort[hostPort]; ok { return p.Peer } p := l.parent.Add(hostPort) p.addSC() ps := newPeerScore(p, l.scoreCalculator.GetScore(p)) l.peersByHostPort[hostPort] = ps l.peerHeap.addPeer(ps) return p } // GetNew returns a new, previously unselected peer from the peer list, or nil, // if no new unselected peer can be found. func (l *PeerList) GetNew(prevSelected map[string]struct{}) (*Peer, error) { l.Lock() defer l.Unlock() if l.peerHeap.Len() == 0 { return nil, ErrNoPeers } // Select a peer, avoiding previously selected peers. If all peers have been previously // selected, then it's OK to repick them. 
peer := l.choosePeer(prevSelected, true /* avoidHost */) if peer == nil { peer = l.choosePeer(prevSelected, false /* avoidHost */) } if peer == nil { return nil, ErrNoNewPeers } return peer, nil } // Get returns a peer from the peer list, or nil if none can be found, // will avoid previously selected peers if possible. func (l *PeerList) Get(prevSelected map[string]struct{}) (*Peer, error) { peer, err := l.GetNew(prevSelected) if err == ErrNoNewPeers { l.Lock() peer = l.choosePeer(nil, false /* avoidHost */) l.Unlock() } else if err != nil { return nil, err } if peer == nil { return nil, ErrNoPeers } return peer, nil } // Remove removes a peer from the peer list. It returns an error if the peer cannot be found. // Remove does not affect connections to the peer in any way. func (l *PeerList) Remove(hostPort string) error { l.Lock() defer l.Unlock() p, ok := l.peersByHostPort[hostPort] if !ok { return ErrPeerNotFound } p.delSC() delete(l.peersByHostPort, hostPort) l.peerHeap.removePeer(p) return nil } func (l *PeerList) choosePeer(prevSelected map[string]struct{}, avoidHost bool) *Peer { var psPopList []*peerScore var ps *peerScore canChoosePeer := func(hostPort string) bool { if _, ok := prevSelected[hostPort]; ok { return false } if avoidHost { if _, ok := prevSelected[getHost(hostPort)]; ok { return false } } return true } size := l.peerHeap.Len() for i := 0; i < size; i++ { popped := l.peerHeap.popPeer() if canChoosePeer(popped.HostPort()) { ps = popped break } psPopList = append(psPopList, popped) } for _, p := range psPopList { heap.Push(l.peerHeap, p) } if ps == nil { return nil } l.peerHeap.pushPeer(ps) ps.chosenCount.Inc() return ps.Peer } // GetOrAdd returns a peer for the given hostPort, creating one if it doesn't yet exist. func (l *PeerList) GetOrAdd(hostPort string) *Peer { // TODO: remove calls to GetOrAdd, use Add instead return l.Add(hostPort) } // Copy returns a copy of the PeerList as a map from hostPort to peer. 
func (l *PeerList) Copy() map[string]*Peer { l.RLock() defer l.RUnlock() listCopy := make(map[string]*Peer) for k, v := range l.peersByHostPort { listCopy[k] = v.Peer } return listCopy } // Len returns the length of the PeerList. func (l *PeerList) Len() int { l.RLock() defer l.RUnlock() return l.peerHeap.Len() } // exists checks if a hostport exists in the peer list. func (l *PeerList) exists(hostPort string) (*peerScore, bool) { l.RLock() ps, ok := l.peersByHostPort[hostPort] l.RUnlock() return ps, ok } // getPeerScore is called to find the peer and its score from a host port key. // Note that at least a Read lock must be held to call this function. func (l *PeerList) getPeerScore(hostPort string) (*peerScore, uint64, bool) { ps, ok := l.peersByHostPort[hostPort] if !ok { return nil, 0, false } return ps, ps.score, ok } // onPeerChange is called when there is a change that may cause the peer's score to change. // The new score is calculated, and the peer heap is updated with the new score if the score changes. func (l *PeerList) onPeerChange(p *Peer) { l.RLock() ps, psScore, ok := l.getPeerScore(p.hostPort) sc := l.scoreCalculator l.RUnlock() if !ok { return } newScore := sc.GetScore(ps.Peer) if newScore == psScore { return } l.Lock() l.updatePeer(ps, newScore) l.Unlock() } // updatePeer is called to update the score of the peer given the existing score. // Note that a Write lock must be held to call this function. func (l *PeerList) updatePeer(ps *peerScore, newScore uint64) { if ps.score == newScore { return } ps.score = newScore l.peerHeap.updatePeer(ps) } // peerScore represents a peer and scoring for the peer heap. // It is not safe for concurrent access, it should only be used through the PeerList. type peerScore struct { *Peer // score according to the current peer list's ScoreCalculator. score uint64 // index of the peerScore in the peerHeap. Used to interact with container/heap. index int // order is the tiebreaker for when score is equal. 
It is set when a peer // is pushed to the heap based on peerHeap.order with jitter. order uint64 } func newPeerScore(p *Peer, score uint64) *peerScore { return &peerScore{ Peer: p, score: score, index: -1, } } // Peer represents a single autobahn service or client with a unique host:port. type Peer struct { sync.RWMutex channel Connectable hostPort string onStatusChanged func(*Peer) onClosedConnRemoved func(*Peer) // scCount is the number of subchannels that this peer is added to. scCount uint32 // connections are mutable, and are protected by the mutex. newConnLock sync.Mutex inboundConnections []*Connection outboundConnections []*Connection chosenCount atomic.Uint64 // onUpdate is a test-only hook. onUpdate func(*Peer) } func newPeer(channel Connectable, hostPort string, onStatusChanged func(*Peer), onClosedConnRemoved func(*Peer)) *Peer { if hostPort == "" { panic("Cannot create peer with blank hostPort") } if onStatusChanged == nil { onStatusChanged = noopOnStatusChanged } return &Peer{ channel: channel, hostPort: hostPort, onStatusChanged: onStatusChanged, onClosedConnRemoved: onClosedConnRemoved, } } // HostPort returns the host:port used to connect to this peer. func (p *Peer) HostPort() string { return p.hostPort } // getConn treats inbound and outbound connections as a single virtual list // that can be indexed. The peer must be read-locked. func (p *Peer) getConn(i int) *Connection { inboundLen := len(p.inboundConnections) if i < inboundLen { return p.inboundConnections[i] } return p.outboundConnections[i-inboundLen] } func (p *Peer) getActiveConnLocked() (*Connection, bool) { allConns := len(p.inboundConnections) + len(p.outboundConnections) if allConns == 0 { return nil, false } // We cycle through the connection list, starting at a random point // to avoid always choosing the same connection. 
var startOffset int if allConns > 1 { startOffset = peerRng.Intn(allConns) } for i := 0; i < allConns; i++ { connIndex := (i + startOffset) % allConns if conn := p.getConn(connIndex); conn.IsActive() { return conn, true } } return nil, false } // getActiveConn will randomly select an active connection. // TODO(prashant): Should we clear inactive connections? // TODO(prashant): Do we want some sort of scoring for connections? func (p *Peer) getActiveConn() (*Connection, bool) { p.RLock() conn, ok := p.getActiveConnLocked() p.RUnlock() return conn, ok } // GetConnection returns an active connection to this peer. If no active connections // are found, it will create a new outbound connection and return it. func (p *Peer) GetConnection(ctx context.Context) (*Connection, error) { if activeConn, ok := p.getActiveConn(); ok { return activeConn, nil } // Lock here to restrict new connection creation attempts to one goroutine p.newConnLock.Lock() defer p.newConnLock.Unlock() // Check active connections again in case someone else got ahead of us. if activeConn, ok := p.getActiveConn(); ok { return activeConn, nil } // No active connections, make a new outgoing connection. return p.Connect(ctx) } // getConnectionRelay gets a connection, and uses the given timeout to lazily // create a context if a new connection is required. func (p *Peer) getConnectionRelay(callTimeout, relayMaxConnTimeout time.Duration) (*Connection, error) { if conn, ok := p.getActiveConn(); ok { return conn, nil } // Lock here to restrict new connection creation attempts to one goroutine p.newConnLock.Lock() defer p.newConnLock.Unlock() // Check active connections again in case someone else got ahead of us. if activeConn, ok := p.getActiveConn(); ok { return activeConn, nil } // Use the lower timeout value of the call timeout and the relay connection timeout. 
timeout := callTimeout if timeout > relayMaxConnTimeout && relayMaxConnTimeout > 0 { timeout = relayMaxConnTimeout } // When the relay creates outbound connections, we don't want those services // to ever connect back to us and send us traffic. We hide the host:port // so that service instances on remote machines don't try to connect back // and don't try to send Hyperbahn traffic on this connection. ctx, cancel := NewContextBuilder(timeout).HideListeningOnOutbound().Build() defer cancel() return p.Connect(ctx) } // addSC adds a reference to a peer from a subchannel (e.g. peer list). func (p *Peer) addSC() { p.Lock() p.scCount++ p.Unlock() } // delSC removes a reference to a peer from a subchannel (e.g. peer list). func (p *Peer) delSC() { p.Lock() p.scCount-- p.Unlock() } // canRemove returns whether this peer can be safely removed from the root peer list. func (p *Peer) canRemove() bool { p.RLock() count := len(p.inboundConnections) + len(p.outboundConnections) + int(p.scCount) p.RUnlock() return count == 0 } // addConnection adds an active connection to the peer's connection list. // If a connection is not active, returns ErrInvalidConnectionState. func (p *Peer) addConnection(c *Connection, direction connectionDirection) error { conns := p.connectionsFor(direction) if c.readState() != connectionActive { return ErrInvalidConnectionState } p.Lock() *conns = append(*conns, c) p.Unlock() // Inform third parties that a peer gained a connection. p.onStatusChanged(p) return nil } func (p *Peer) connectionsFor(direction connectionDirection) *[]*Connection { if direction == inbound { return &p.inboundConnections } return &p.outboundConnections } // removeConnection will check remove the connection if it exists on connsPtr // and returns whether it removed the connection. 
func (p *Peer) removeConnection(connsPtr *[]*Connection, changed *Connection) bool { conns := *connsPtr for i, c := range conns { if c == changed { // Remove the connection by moving the last item forward, and slicing the list. last := len(conns) - 1 conns[i], conns[last] = conns[last], nil *connsPtr = conns[:last] return true } } return false } // connectionStateChanged is called when one of the peers' connections states changes. // All non-active connections are removed from the peer. The connection will // still be tracked by the channel until it's completely closed. func (p *Peer) connectionCloseStateChange(changed *Connection) { if changed.IsActive() { return } p.Lock() found := p.removeConnection(&p.inboundConnections, changed) if !found { found = p.removeConnection(&p.outboundConnections, changed) } p.Unlock() if found { p.onClosedConnRemoved(p) // Inform third parties that a peer lost a connection. p.onStatusChanged(p) } } // Connect adds a new outbound connection to the peer. func (p *Peer) Connect(ctx context.Context) (*Connection, error) { return p.channel.Connect(ctx, p.hostPort) } // BeginCall starts a new call to this specific peer, returning an OutboundCall that can // be used to write the arguments of the call. func (p *Peer) BeginCall(ctx context.Context, serviceName, methodName string, callOptions *CallOptions) (*OutboundCall, error) { if callOptions == nil { callOptions = defaultCallOptions } callOptions.RequestState.AddSelectedPeer(p.HostPort()) if err := validateCall(ctx, serviceName, methodName, callOptions); err != nil { return nil, err } conn, err := p.GetConnection(ctx) if err != nil { return nil, err } call, err := conn.beginCall(ctx, serviceName, methodName, callOptions) if err != nil { return nil, err } return call, err } // NumConnections returns the number of inbound and outbound connections for this peer. 
func (p *Peer) NumConnections() (inbound int, outbound int) { p.RLock() inbound = len(p.inboundConnections) outbound = len(p.outboundConnections) p.RUnlock() return inbound, outbound } // NumPendingOutbound returns the number of pending outbound calls. func (p *Peer) NumPendingOutbound() int { count := 0 p.RLock() for _, c := range p.outboundConnections { count += c.outbound.count() } for _, c := range p.inboundConnections { count += c.outbound.count() } p.RUnlock() return count } func (p *Peer) runWithConnections(f func(*Connection)) { p.RLock() for _, c := range p.inboundConnections { f(c) } for _, c := range p.outboundConnections { f(c) } p.RUnlock() } func (p *Peer) callOnUpdateComplete() { p.RLock() f := p.onUpdate p.RUnlock() if f != nil { f(p) } } func noopOnStatusChanged(*Peer) {} // isEphemeralHostPort returns if hostPort is the default ephemeral hostPort. func isEphemeralHostPort(hostPort string) bool { return hostPort == "" || hostPort == ephemeralHostPort || strings.HasSuffix(hostPort, ":0") } ================================================ FILE: peer_bench_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "testing" "time" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/require" ) func benchmarkGetConnection(b *testing.B, numIncoming, numOutgoing int) { ctx, cancel := NewContext(10 * time.Second) defer cancel() s1 := testutils.NewServer(b, nil) s2 := testutils.NewServer(b, nil) defer s1.Close() defer s2.Close() for i := 0; i < numOutgoing; i++ { _, err := s1.Connect(ctx, s2.PeerInfo().HostPort) require.NoError(b, err, "Connect from s1 -> s2 failed") } for i := 0; i < numIncoming; i++ { _, err := s2.Connect(ctx, s1.PeerInfo().HostPort) require.NoError(b, err, "Connect from s2 -> s1 failed") } peer := s1.Peers().GetOrAdd(s2.PeerInfo().HostPort) b.ResetTimer() for i := 0; i < b.N; i++ { peer.GetConnection(ctx) } } func BenchmarkGetConnection0In1Out(b *testing.B) { benchmarkGetConnection(b, 0, 1) } func BenchmarkGetConnection1In0Out(b *testing.B) { benchmarkGetConnection(b, 1, 0) } func BenchmarkGetConnection5In5Out(b *testing.B) { benchmarkGetConnection(b, 5, 5) } ================================================ FILE: peer_heap.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "container/heap" "math/rand" "github.com/uber/tchannel-go/trand" ) // peerHeap maintains a min-heap of peers based on the peers' score. All method // calls must be serialized externally. 
type peerHeap struct { peerScores []*peerScore rng *rand.Rand order uint64 } func newPeerHeap() *peerHeap { return &peerHeap{rng: trand.NewSeeded()} } func (ph peerHeap) Len() int { return len(ph.peerScores) } func (ph *peerHeap) Less(i, j int) bool { if ph.peerScores[i].score == ph.peerScores[j].score { return ph.peerScores[i].order < ph.peerScores[j].order } return ph.peerScores[i].score < ph.peerScores[j].score } func (ph peerHeap) Swap(i, j int) { ph.peerScores[i], ph.peerScores[j] = ph.peerScores[j], ph.peerScores[i] ph.peerScores[i].index = i ph.peerScores[j].index = j } // Push implements heap Push interface func (ph *peerHeap) Push(x interface{}) { n := len(ph.peerScores) item := x.(*peerScore) item.index = n ph.peerScores = append(ph.peerScores, item) } // Pop implements heap Pop interface func (ph *peerHeap) Pop() interface{} { old := *ph n := len(old.peerScores) item := old.peerScores[n-1] item.index = -1 // for safety ph.peerScores = old.peerScores[:n-1] return item } // updatePeer updates the score for the given peer. func (ph *peerHeap) updatePeer(peerScore *peerScore) { heap.Fix(ph, peerScore.index) } // removePeer remove peer at specific index. func (ph *peerHeap) removePeer(peerScore *peerScore) { heap.Remove(ph, peerScore.index) } // popPeer pops the top peer of the heap. func (ph *peerHeap) popPeer() *peerScore { return heap.Pop(ph).(*peerScore) } // pushPeer pushes the new peer into the heap. func (ph *peerHeap) pushPeer(peerScore *peerScore) { ph.order++ newOrder := ph.order // randRange will affect the deviation of peer's chosenCount randRange := ph.Len()/2 + 1 peerScore.order = newOrder + uint64(ph.rng.Intn(randRange)) heap.Push(ph, peerScore) } func (ph *peerHeap) swapOrder(i, j int) { if i == j { return } ph.peerScores[i].order, ph.peerScores[j].order = ph.peerScores[j].order, ph.peerScores[i].order heap.Fix(ph, i) heap.Fix(ph, j) } // AddPeer adds a peer to the peer heap. 
func (ph *peerHeap) addPeer(peerScore *peerScore) { ph.pushPeer(peerScore) // Pick a random element, and swap the order with that peerScore. r := ph.rng.Intn(ph.Len()) ph.swapOrder(peerScore.index, r) } // Exposed for testing purposes. func (ph *peerHeap) peek() *peerScore { return ph.peerScores[0] } ================================================ FILE: peer_heap_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"math"
	"math/rand"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestPeerHeap pushes peers with random scores and verifies that the heap
// reports the right size, exposes the minimum score at the top, and pops
// peers in non-decreasing score order.
func TestPeerHeap(t *testing.T) {
	const numPeers = 10

	rng := rand.New(rand.NewSource(time.Now().UnixNano()))

	ph := newPeerHeap()
	scored := make([]*peerScore, numPeers)
	minScore := uint64(math.MaxInt64)
	for i := range scored {
		ps := newPeerScore(&Peer{}, uint64(rng.Intn(numPeers*5)))
		scored[i] = ps
		if ps.score < minScore {
			minScore = ps.score
		}
	}

	for _, ps := range scored {
		ph.pushPeer(ps)
	}

	assert.Equal(t, numPeers, ph.Len(), "Incorrect peer heap numPeers")
	assert.Equal(t, minScore, ph.peek().score, "peerHeap top peer is not minimum")

	prevScore := ph.popPeer().score
	for popped := 1; popped < numPeers; popped++ {
		assert.Equal(t, numPeers-popped, ph.Len(), "Incorrect peer heap numPeers")
		score := ph.popPeer().score
		assert.True(t, score >= prevScore, "The order of the heap is invalid")
		prevScore = score
	}
}

================================================
FILE: peer_internal_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "testing" "github.com/stretchr/testify/assert" ) func TestIsEphemeralHostPort(t *testing.T) { tests := []struct { hostPort string want bool }{ {"", true}, {ephemeralHostPort, true}, {"127.0.0.1:0", true}, {"10.1.1.1:0", true}, {"127.0.0.1:1", false}, {"10.1.1.1:1", false}, } for _, tt := range tests { got := isEphemeralHostPort(tt.hostPort) assert.Equal(t, tt.want, got, "Unexpected result for %q", tt.hostPort) } } ================================================ FILE: peer_strategies.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import "math"

// ScoreCalculator defines the interface to calculate the score.
type ScoreCalculator interface {
	GetScore(p *Peer) uint64
}

// ScoreCalculatorFunc is an adapter that allows functions to be used as ScoreCalculator
type ScoreCalculatorFunc func(p *Peer) uint64

// GetScore calls the underlying function.
func (f ScoreCalculatorFunc) GetScore(p *Peer) uint64 {
	return f(p)
}

// zeroCalculator scores every peer as 0.
type zeroCalculator struct{}

// GetScore always returns 0, regardless of the peer.
func (zeroCalculator) GetScore(p *Peer) uint64 {
	return 0
}

func newZeroCalculator() zeroCalculator {
	return zeroCalculator{}
}

// leastPendingCalculator scores connected peers by their number of pending
// outbound calls and gives unconnected peers the maximum score.
type leastPendingCalculator struct{}

// GetScore returns math.MaxUint64 for a peer without connections, and the
// peer's pending outbound call count otherwise.
func (leastPendingCalculator) GetScore(p *Peer) uint64 {
	if in, out := p.NumConnections(); in+out == 0 {
		return math.MaxUint64
	}

	return uint64(p.NumPendingOutbound())
}

// newLeastPendingCalculator returns a strategy that prefers any connected peer.
// Within connected peers, least pending calls is used. Peers with less pending outbound calls
// get a smaller score.
func newLeastPendingCalculator() leastPendingCalculator {
	return leastPendingCalculator{}
}

// preferIncomingCalculator scores peers in tiers: peers with incoming
// connections first, then other connected peers, then unconnected peers.
type preferIncomingCalculator struct{}

// GetScore returns math.MaxUint64 for a peer without connections. A peer
// with only outbound connections gets its pending outbound call count offset
// by math.MaxInt32; a peer with inbound connections gets the plain count.
func (preferIncomingCalculator) GetScore(p *Peer) uint64 {
	switch in, out := p.NumConnections(); {
	case in+out == 0:
		return math.MaxUint64
	case in == 0:
		return math.MaxInt32 + uint64(p.NumPendingOutbound())
	default:
		return uint64(p.NumPendingOutbound())
	}
}

// newPreferIncomingCalculator returns a strategy that prefers peers with incoming connections.
// The scoring tiers are:
// Peers with incoming connections, peers with any connections, unconnected peers.
// Within each tier, least pending calls is used. Peers with less pending outbound calls
// get a smaller score.
func newPreferIncomingCalculator() preferIncomingCalculator {
	return preferIncomingCalculator{}
}

================================================
FILE: peer_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "fmt" "sort" "sync" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/benchmark" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" ) func fakePeer(t *testing.T, ch *Channel, hostPort string) *Peer { ch.Peers().Add(hostPort) peer, err := ch.Peers().Get(nil) require.NoError(t, err, "Unexpected error getting peer from heap.") require.Equal(t, hostPort, peer.HostPort(), "Got unexpected peer.") in, out := peer.NumConnections() require.Equal(t, 0, in, "Expected new peer to have no incoming connections.") require.Equal(t, 0, out, "Expected new peer to have no outgoing connections.") return peer } func assertNumConnections(t *testing.T, peer *Peer, in, out int) { actualIn, actualOut := peer.NumConnections() assert.Equal(t, actualIn, in, "Expected %v incoming connection.", in) assert.Equal(t, actualOut, out, "Expected %v outgoing connection.", out) } func TestGetPeerNoPeer(t *testing.T) { ch := testutils.NewClient(t, nil) defer ch.Close() peer, err := ch.Peers().Get(nil) assert.Equal(t, ErrNoPeers, err, "Empty peer list should return error") assert.Nil(t, peer, "should not return peer") } func TestGetPeerSinglePeer(t *testing.T) { ch := testutils.NewClient(t, nil) defer ch.Close() ch.Peers().Add("1.1.1.1:1234") peer, err := ch.Peers().Get(nil) assert.NoError(t, err, "peer list should return contained element") assert.Equal(t, "1.1.1.1:1234", peer.HostPort(), "returned peer mismatch") } func TestPeerUpdatesLen(t *testing.T) { ch := testutils.NewClient(t, nil) defer ch.Close() assert.Zero(t, ch.Peers().Len()) for i := 1; i < 5; i++ { ch.Peers().Add(fmt.Sprintf("1.1.1.1:%d", i)) assert.Equal(t, ch.Peers().Len(), i) } for i := 4; i > 0; i-- { assert.Equal(t, ch.Peers().Len(), i) ch.Peers().Remove(fmt.Sprintf("1.1.1.1:%d", i)) } assert.Zero(t, ch.Peers().Len()) } func TestGetPeerAvoidPrevSelected(t *testing.T) { const ( peer1 = "1.1.1.1:1" peer2 = "2.2.2.2:2" peer3 = 
"3.3.3.3:3" peer4 = "3.3.3.3:4" ) ch := testutils.NewClient(t, nil) defer ch.Close() a, m := testutils.StrArray, testutils.StrMap tests := []struct { msg string peers []string prevSelected []string expected map[string]struct{} }{ { msg: "no prevSelected", peers: a(peer1), expected: m(peer1), }, { msg: "ignore single hostPort in prevSelected", peers: a(peer1, peer2), prevSelected: a(peer1), expected: m(peer2), }, { msg: "ignore multiple hostPorts in prevSelected", peers: a(peer1, peer2, peer3), prevSelected: a(peer1, peer2), expected: m(peer3), }, { msg: "only peer is in prevSelected", peers: a(peer1), prevSelected: a(peer1), expected: m(peer1), }, { msg: "all peers are in prevSelected", peers: a(peer1, peer2, peer3), prevSelected: a(peer1, peer2, peer3), expected: m(peer1, peer2, peer3), }, { msg: "prevSelected host should be ignored", peers: a(peer1, peer3, peer4), prevSelected: a(peer3), expected: m(peer1), }, { msg: "prevSelected only has single host", peers: a(peer3, peer4), prevSelected: a(peer3), expected: m(peer4), }, } for i, tt := range tests { peers := ch.GetSubChannel(fmt.Sprintf("test%d", i), Isolated).Peers() for _, p := range tt.peers { peers.Add(p) } rs := &RequestState{} for _, selected := range tt.prevSelected { rs.AddSelectedPeer(selected) } gotPeer, err := peers.Get(rs.PrevSelectedPeers()) if err != nil { t.Errorf("Got unexpected error selecting peer: %v", err) continue } newPeer, err := peers.GetNew(rs.PrevSelectedPeers()) if len(tt.peers) == len(tt.prevSelected) { if newPeer != nil || err != ErrNoNewPeers { t.Errorf("%s: newPeer should not have been found %v: %v\n", tt.msg, newPeer, err) } } else { if gotPeer != newPeer || err != nil { t.Errorf("%s: expected equal peers, got %v new %v: %v\n", tt.msg, gotPeer, newPeer, err) } } got := gotPeer.HostPort() if _, ok := tt.expected[got]; !ok { t.Errorf("%s: got unexpected peer, expected one of %v got %v\n Peers = %v PrevSelected = %v", tt.msg, tt.expected, got, tt.peers, tt.prevSelected) } } } func 
TestPeerRemoveClosedConnection(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { client := testutils.NewClient(t, nil) defer client.Close() p := client.Peers().Add(hostPort) c1, err := p.Connect(ctx) require.NoError(t, err, "Failed to connect") require.NoError(t, err, c1.Ping(ctx)) c2, err := p.Connect(ctx) require.NoError(t, err, "Failed to connect") require.NoError(t, err, c2.Ping(ctx)) require.NoError(t, c1.Close(), "Failed to close first connection") _, outConns := p.NumConnections() assert.Equal(t, 1, outConns, "Expected 1 remaining outgoing connection") c, err := p.GetConnection(ctx) require.NoError(t, err, "GetConnection failed") assert.Equal(t, c2, c, "Expected second active connection") }) } func TestPeerConnectCancelled(t *testing.T) { WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { ctx, cancel := NewContext(100 * time.Millisecond) cancel() _, err := ch.Connect(ctx, "10.255.255.1:1") require.Error(t, err, "Connect should fail") assert.EqualError(t, err, ErrRequestCancelled.Error(), "Unexpected error") }) } func TestPeerGetConnectionWithNoActiveConnections(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { client := testutils.NewClient(t, nil) defer client.Close() var ( wg sync.WaitGroup lock sync.Mutex conn *Connection concurrency = 10 p = client.Peers().Add(hostPort) ) for i := 0; i < concurrency; i++ { wg.Add(1) go func() { defer wg.Done() c, err := p.GetConnection(ctx) require.NoError(t, err, "GetConnection failed") lock.Lock() defer lock.Unlock() if conn == nil { conn = c } else { assert.Equal(t, conn, c, "Expected the same active connection") } }() } wg.Wait() _, outbound := p.NumConnections() assert.Equal(t, 1, outbound, "Expected 1 active outbound connetion") }) } func TestInboundEphemeralPeerRemoved(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() 
// No relay, since we look for the exact host:port in peer lists. opts := testutils.NewOpts().NoRelay() WithVerifiedServer(t, opts, func(ch *Channel, hostPort string) { client := testutils.NewClient(t, nil) assert.NoError(t, client.Ping(ctx, hostPort), "Ping to server failed") // Server should have a host:port in the root peers for the client. var clientHP string peers := ch.RootPeers().Copy() for k := range peers { clientHP = k } waitTillInboundEmpty(t, ch, clientHP, func() { client.Close() }) assert.Equal(t, ChannelClosed, client.State(), "Client should be closed") _, ok := ch.RootPeers().Get(clientHP) assert.False(t, ok, "server's root peers should remove peer for client on close") }) } func TestOutboundEphemeralPeerRemoved(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { outbound := testutils.NewServer(t, testutils.NewOpts().SetServiceName("asd ")) assert.NoError(t, ch.Ping(ctx, outbound.PeerInfo().HostPort), "Ping to outbound failed") outboundHP := outbound.PeerInfo().HostPort // Server should have a peer for hostPort that should be gone. 
waitTillNConnections(t, ch, outboundHP, 0, 0, func() { outbound.Close() }) assert.Equal(t, ChannelClosed, outbound.State(), "Outbound should be closed") _, ok := ch.RootPeers().Get(outboundHP) assert.False(t, ok, "server's root peers should remove outbound peer") }) } func TestOutboundPeerNotAdded(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() WithVerifiedServer(t, nil, func(server *Channel, hostPort string) { server.Register(raw.Wrap(newTestHandler(t)), "echo") ch := testutils.NewClient(t, nil) defer ch.Close() ch.Ping(ctx, hostPort) raw.Call(ctx, ch, hostPort, server.PeerInfo().ServiceName, "echo", nil, nil) peer, err := ch.Peers().Get(nil) assert.Equal(t, ErrNoPeers, err, "Ping should not add peers") assert.Nil(t, peer, "Expected no peer to be returned") }) } func TestRemovePeerNotFound(t *testing.T) { ch := testutils.NewClient(t, nil) defer ch.Close() peers := ch.Peers() peers.Add("1.1.1.1:1") assert.Error(t, peers.Remove("not-found"), "Remove should fa") assert.NoError(t, peers.Remove("1.1.1.1:1"), "Remove shouldn't fail for existing peer") } func TestPeerRemovedFromRootPeers(t *testing.T) { tests := []struct { addHostPort bool removeHostPort bool expectFound bool }{ { addHostPort: true, expectFound: true, }, { addHostPort: true, removeHostPort: true, expectFound: false, }, { addHostPort: false, expectFound: false, }, } ctx, cancel := NewContext(time.Second) defer cancel() for _, tt := range tests { opts := testutils.NewOpts().NoRelay() WithVerifiedServer(t, opts, func(server *Channel, hostPort string) { ch := testutils.NewServer(t, nil) clientHP := ch.PeerInfo().HostPort if tt.addHostPort { server.Peers().Add(clientHP) } assert.NoError(t, ch.Ping(ctx, hostPort), "Ping failed") if tt.removeHostPort { require.NoError(t, server.Peers().Remove(clientHP), "Failed to remove peer") } waitTillInboundEmpty(t, server, clientHP, func() { ch.Close() }) rootPeers := server.RootPeers() _, found := rootPeers.Get(clientHP) assert.Equal(t, 
tt.expectFound, found, "Peer found mismatch, addHostPort: %v", tt.addHostPort) }) } } func TestPeerSelectionConnClosed(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() WithVerifiedServer(t, nil, func(server *Channel, hostPort string) { client := testutils.NewServer(t, nil) defer client.Close() // Ping will create an outbound connection from client -> server. require.NoError(t, testutils.Ping(client, server), "Ping failed") waitTillInboundEmpty(t, server, client.PeerInfo().HostPort, func() { peer, ok := client.RootPeers().Get(server.PeerInfo().HostPort) require.True(t, ok, "Client has no peer for %v", server.PeerInfo()) conn, err := peer.GetConnection(ctx) require.NoError(t, err, "Failed to get a connection") conn.Close() }) // Make sure the closed connection is not used. for i := 0; i < 10; i++ { require.NoError(t, testutils.Ping(client, server), "Ping failed") } }) } func TestPeerSelectionPreferIncoming(t *testing.T) { tests := []struct { numIncoming, numOutgoing, numUnconnected int isolated bool expectedIncoming int expectedOutgoing int expectedUnconnected int }{ { numIncoming: 5, numOutgoing: 5, numUnconnected: 5, expectedIncoming: 5, }, { numOutgoing: 5, numUnconnected: 5, expectedOutgoing: 5, }, { numUnconnected: 5, expectedUnconnected: 5, }, { numIncoming: 5, numOutgoing: 5, numUnconnected: 5, isolated: true, expectedIncoming: 5, expectedOutgoing: 5, }, { numOutgoing: 5, numUnconnected: 5, isolated: true, expectedOutgoing: 5, }, { numIncoming: 5, numUnconnected: 5, isolated: true, expectedIncoming: 5, }, { numUnconnected: 5, isolated: true, expectedUnconnected: 5, }, } for _, tt := range tests { // We need to directly connect from the server to the client and verify // the exact peers. 
opts := testutils.NewOpts().NoRelay() WithVerifiedServer(t, opts, func(ch *Channel, hostPort string) { ctx, cancel := NewContext(time.Second) defer cancel() selectedIncoming := make(map[string]int) selectedOutgoing := make(map[string]int) selectedUnconnected := make(map[string]int) peers := ch.Peers() if tt.isolated { peers = ch.GetSubChannel("isolated", Isolated).Peers() } // 5 peers that make incoming connections to ch. for i := 0; i < tt.numIncoming; i++ { incoming, _, incomingHP := NewServer(t, &testutils.ChannelOpts{ServiceName: fmt.Sprintf("incoming%d", i)}) defer incoming.Close() assert.NoError(t, incoming.Ping(ctx, ch.PeerInfo().HostPort), "Ping failed") peers.Add(incomingHP) selectedIncoming[incomingHP] = 0 } // 5 random peers that don't have any connections. for i := 0; i < tt.numUnconnected; i++ { hp := fmt.Sprintf("1.1.1.1:1%d", i) peers.Add(hp) selectedUnconnected[hp] = 0 } // 5 random peers that we have outgoing connections to. for i := 0; i < tt.numOutgoing; i++ { outgoing, _, outgoingHP := NewServer(t, &testutils.ChannelOpts{ServiceName: fmt.Sprintf("outgoing%d", i)}) defer outgoing.Close() assert.NoError(t, ch.Ping(ctx, outgoingHP), "Ping failed") peers.Add(outgoingHP) selectedOutgoing[outgoingHP] = 0 } var mu sync.Mutex checkMap := func(m map[string]int, peer string) bool { mu.Lock() defer mu.Unlock() if _, ok := m[peer]; !ok { return false } m[peer]++ return true } numSelectedPeers := func(m map[string]int) int { count := 0 for _, v := range m { if v > 0 { count++ } } return count } peerCheck := func() { for i := 0; i < 100; i++ { peer, err := peers.Get(nil) if assert.NoError(t, err, "Peers.Get failed") { peerHP := peer.HostPort() inMap := checkMap(selectedIncoming, peerHP) || checkMap(selectedOutgoing, peerHP) || checkMap(selectedUnconnected, peerHP) assert.True(t, inMap, "Couldn't find peer %v in any of our maps", peerHP) } } } // Now select peers in parallel var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() 
peerCheck() }() } wg.Wait() assert.Equal(t, tt.expectedIncoming, numSelectedPeers(selectedIncoming), "Selected incoming mismatch: %v", selectedIncoming) assert.Equal(t, tt.expectedOutgoing, numSelectedPeers(selectedOutgoing), "Selected outgoing mismatch: %v", selectedOutgoing) assert.Equal(t, tt.expectedUnconnected, numSelectedPeers(selectedUnconnected), "Selected unconnected mismatch: %v", selectedUnconnected) }) } } type peerTest struct { t testing.TB channels []*Channel } // NewService will return a new server channel and the host port. func (pt *peerTest) NewService(t testing.TB, svcName, processName string) (*Channel, string) { opts := testutils.NewOpts().SetServiceName(svcName).SetProcessName(processName) ch := testutils.NewServer(t, opts) pt.channels = append(pt.channels, ch) return ch, ch.PeerInfo().HostPort } // CleanUp will clean up all channels started as part of the peer test. func (pt *peerTest) CleanUp() { for _, ch := range pt.channels { ch.Close() } } func TestPeerSelection(t *testing.T) { pt := &peerTest{t: t} defer pt.CleanUp() WithVerifiedServer(t, &testutils.ChannelOpts{ServiceName: "S1"}, func(ch *Channel, hostPort string) { doPing := func(ch *Channel) { ctx, cancel := NewContext(time.Second) defer cancel() assert.NoError(t, ch.Ping(ctx, hostPort), "Ping failed") } strategy, count := createScoreStrategy(0, 1) s2, _ := pt.NewService(t, "S2", "S2") defer s2.Close() s2.GetSubChannel("S1").Peers().SetStrategy(strategy) s2.GetSubChannel("S1").Peers().Add(hostPort) doPing(s2) assert.EqualValues(t, 4, count.Load(), "Expect 4 exchange updates: peer add, new conn, ping, pong") }) } func getAllPeers(t *testing.T, pl *PeerList) []string { prevSelected := make(map[string]struct{}) var got []string for { peer, err := pl.Get(prevSelected) require.NoError(t, err, "Peer.Get failed") hp := peer.HostPort() if _, ok := prevSelected[hp]; ok { break } prevSelected[hp] = struct{}{} got = append(got, hp) } return got } func reverse(s []string) { for i := 0; i < 
len(s)/2; i++ { j := len(s) - i - 1 s[i], s[j] = s[j], s[i] } } func TestIsolatedPeerHeap(t *testing.T) { const numPeers = 10 ch := testutils.NewClient(t, nil) defer ch.Close() ps1 := createSubChannelWNewStrategy(ch, "S1", numPeers, 1) ps2 := createSubChannelWNewStrategy(ch, "S2", numPeers, -1, Isolated) hostports := make([]string, numPeers) for i := 0; i < numPeers; i++ { hostports[i] = fmt.Sprintf("127.0.0.1:%d", i) ps1.Add(hostports[i]) ps2.Add(hostports[i]) } ps1Expected := append([]string(nil), hostports...) assert.Equal(t, ps1Expected, getAllPeers(t, ps1), "Unexpected peer order") ps2Expected := append([]string(nil), hostports...) reverse(ps2Expected) assert.Equal(t, ps2Expected, getAllPeers(t, ps2), "Unexpected peer order") } func TestPeerSelectionRanking(t *testing.T) { const numPeers = 10 const numIterations = 1000 // Selected is a map from rank -> [peer, count] // It tracks how often a peer gets selected at a specific rank. selected := make([]map[string]int, numPeers) for i := 0; i < numPeers; i++ { selected[i] = make(map[string]int) } for i := 0; i < numIterations; i++ { ch := testutils.NewClient(t, nil) defer ch.Close() ch.SetRandomSeed(int64(i * 100)) for i := 0; i < numPeers; i++ { hp := fmt.Sprintf("127.0.0.1:60%v", i) ch.Peers().Add(hp) } for i := 0; i < numPeers; i++ { peer, err := ch.Peers().Get(nil) require.NoError(t, err, "Peers.Get failed") selected[i][peer.HostPort()]++ } } for _, m := range selected { testDistribution(t, m, 50, 150) } } func createScoreStrategy(initial, delta int64) (calc ScoreCalculator, retCount *atomic.Uint64) { var ( count atomic.Uint64 score atomic.Uint64 ) return ScoreCalculatorFunc(func(p *Peer) uint64 { count.Add(1) return score.Add(uint64(delta)) }), &count } func createSubChannelWNewStrategy(ch *Channel, name string, initial, delta int64, opts ...SubChannelOption) *PeerList { strategy, _ := createScoreStrategy(initial, delta) sc := ch.GetSubChannel(name, opts...) 
ps := sc.Peers() ps.SetStrategy(strategy) return ps } func testDistribution(t testing.TB, counts map[string]int, min, max float64) { for k, v := range counts { if float64(v) < min || float64(v) > max { t.Errorf("Key %v has value %v which is out of range %v-%v", k, v, min, max) } } } // waitTillNConnetions will run f which should end up causing the peer with hostPort in ch // to have the specified number of inbound and outbound connections. // If the number of connections does not match after a second, the test is failed. func waitTillNConnections(t *testing.T, ch *Channel, hostPort string, inbound, outbound int, f func()) { peer, ok := ch.RootPeers().Get(hostPort) if !ok { return } var ( i = -1 o = -1 ) inboundEmpty := make(chan struct{}) var onUpdateOnce sync.Once onUpdate := func(p *Peer) { if i, o = p.NumConnections(); (i == inbound || inbound == -1) && (o == outbound || outbound == -1) { onUpdateOnce.Do(func() { close(inboundEmpty) }) } } peer.SetOnUpdate(onUpdate) f() select { case <-inboundEmpty: return case <-time.After(time.Second): t.Errorf("Timed out waiting for peer %v to have (in: %v, out: %v) connections, got (in: %v, out: %v)", hostPort, inbound, outbound, i, o) } } // waitTillInboundEmpty will run f which should end up causing the peer with hostPort in ch // to have 0 inbound connections. It will fail the test after a second. func waitTillInboundEmpty(t *testing.T, ch *Channel, hostPort string, f func()) { waitTillNConnections(t, ch, hostPort, 0, -1, f) } type peerSelectionTest struct { peerTest // numPeers is the number of peers added to the client channel. numPeers int // numAffinity is the number of affinity nodes. numAffinity int // numAffinityWithNoCall is the number of affinity nodes which doesn't send call req to client. numAffinityWithNoCall int // numConcurrent is the number of concurrent goroutine to make outbound calls. 
numConcurrent int // hasInboundCall is the bool flag to tell whether to have inbound calls from affinity nodes hasInboundCall bool servers []*Channel affinity []*Channel affinityWithNoCall []*Channel client *Channel } func (pt *peerSelectionTest) setup(t testing.TB) { pt.setupServers(t) pt.setupClient(t) pt.setupAffinity(t) } // setupServers will create numPeer servers, and register handlers on them. func (pt *peerSelectionTest) setupServers(t testing.TB) { pt.servers = make([]*Channel, pt.numPeers) // Set up numPeers servers. for i := 0; i < pt.numPeers; i++ { pt.servers[i], _ = pt.NewService(t, "server", fmt.Sprintf("server-%v", i)) pt.servers[i].Register(raw.Wrap(newTestHandler(pt.t)), "echo") } } func (pt *peerSelectionTest) setupAffinity(t testing.TB) { pt.affinity = make([]*Channel, pt.numAffinity) for i := range pt.affinity { pt.affinity[i] = pt.servers[i] } pt.affinityWithNoCall = make([]*Channel, pt.numAffinityWithNoCall) for i := range pt.affinityWithNoCall { pt.affinityWithNoCall[i] = pt.servers[i+pt.numAffinity] } var wg sync.WaitGroup wg.Add(pt.numAffinity) // Connect from the affinity nodes to the service. hostport := pt.client.PeerInfo().HostPort serviceName := pt.client.PeerInfo().ServiceName for _, affinity := range pt.affinity { go func(affinity *Channel) { affinity.Peers().Add(hostport) pt.makeCall(affinity.GetSubChannel(serviceName)) wg.Done() }(affinity) } wg.Wait() wg.Add(pt.numAffinityWithNoCall) for _, p := range pt.affinityWithNoCall { go func(p *Channel) { // use ping to build connection without sending call req. 
pt.sendPing(p, hostport) wg.Done() }(p) } wg.Wait() } func (pt *peerSelectionTest) setupClient(t testing.TB) { pt.client, _ = pt.NewService(t, "client", "client") pt.client.Register(raw.Wrap(newTestHandler(pt.t)), "echo") for _, server := range pt.servers { pt.client.Peers().Add(server.PeerInfo().HostPort) } } func (pt *peerSelectionTest) makeCall(sc *SubChannel) { ctx, cancel := NewContext(time.Second) defer cancel() _, _, _, err := raw.CallSC(ctx, sc, "echo", nil, nil) assert.NoError(pt.t, err, "raw.Call failed") } func (pt *peerSelectionTest) sendPing(ch *Channel, hostport string) { ctx, cancel := NewContext(time.Second) defer cancel() err := ch.Ping(ctx, hostport) assert.NoError(pt.t, err, "ping failed") } func (pt *peerSelectionTest) runStressSimple(b *testing.B) { var wg sync.WaitGroup wg.Add(pt.numConcurrent) // server outbound request sc := pt.client.GetSubChannel("server") for i := 0; i < pt.numConcurrent; i++ { go func(sc *SubChannel) { defer wg.Done() for j := 0; j < b.N; j++ { pt.makeCall(sc) } }(sc) } wg.Wait() } func (pt *peerSelectionTest) runStress() { numClock := pt.numConcurrent + pt.numAffinity clocks := make([]chan struct{}, numClock) for i := 0; i < numClock; i++ { clocks[i] = make(chan struct{}) } var wg sync.WaitGroup wg.Add(numClock) // helper that will make a request every n ticks. 
reqEveryNTicks := func(n int, sc *SubChannel, clock <-chan struct{}) { defer wg.Done() for { for i := 0; i < n; i++ { _, ok := <-clock if !ok { return } } pt.makeCall(sc) } } // server outbound request sc := pt.client.GetSubChannel("server") for i := 0; i < pt.numConcurrent; i++ { go reqEveryNTicks(1, sc, clocks[i]) } // affinity incoming requests if pt.hasInboundCall { serviceName := pt.client.PeerInfo().ServiceName for i, affinity := range pt.affinity { go reqEveryNTicks(1, affinity.GetSubChannel(serviceName), clocks[i+pt.numConcurrent]) } } tickAllClocks := func() { for i := 0; i < numClock; i++ { clocks[i] <- struct{}{} } } const tickNum = 10000 for i := 0; i < tickNum; i++ { if i%(tickNum/10) == 0 { fmt.Printf("Stress test progress: %v\n", 100*i/tickNum) } tickAllClocks() } for i := 0; i < numClock; i++ { close(clocks[i]) } wg.Wait() } // Run these commands before run the benchmark. // sudo sysctl w kern.maxfiles=50000 // ulimit n 50000 func BenchmarkSimplePeerHeapPerf(b *testing.B) { pt := &peerSelectionTest{ peerTest: peerTest{t: b}, numPeers: 1000, numConcurrent: 100, } defer pt.CleanUp() pt.setup(b) b.ResetTimer() pt.runStressSimple(b) } func TestPeerHeapPerf(t *testing.T) { CheckStress(t) tests := []struct { numserver int affinityRatio float64 numConcurrent int hasInboundCall bool }{ { numserver: 1000, affinityRatio: 0.1, numConcurrent: 5, hasInboundCall: true, }, { numserver: 1000, affinityRatio: 0.1, numConcurrent: 1, hasInboundCall: true, }, { numserver: 100, affinityRatio: 0.1, numConcurrent: 1, hasInboundCall: true, }, } for _, tt := range tests { peerHeapStress(t, tt.numserver, tt.affinityRatio, tt.numConcurrent, tt.hasInboundCall) } } func peerHeapStress(t testing.TB, numserver int, affinityRatio float64, numConcurrent int, hasInboundCall bool) { pt := &peerSelectionTest{ peerTest: peerTest{t: t}, numPeers: numserver, numConcurrent: numConcurrent, hasInboundCall: hasInboundCall, numAffinity: int(float64(numserver) * affinityRatio), 
numAffinityWithNoCall: 3, } defer pt.CleanUp() pt.setup(t) pt.runStress() validateStressTest(t, pt.client, pt.numAffinity, pt.numAffinityWithNoCall) } func validateStressTest(t testing.TB, server *Channel, numAffinity int, numAffinityWithNoCall int) { state := server.IntrospectState(&IntrospectionOptions{IncludeEmptyPeers: true}) countsByPeer := make(map[string]int) var counts []int for _, peer := range state.Peers { p, ok := state.RootPeers[peer.HostPort] assert.True(t, ok, "Missing peer.") if p.ChosenCount != 0 { countsByPeer[p.HostPort] = int(p.ChosenCount) counts = append(counts, int(p.ChosenCount)) } } // when number of affinity is zero, all peer suppose to be chosen. if numAffinity == 0 && numAffinityWithNoCall == 0 { numAffinity = len(state.Peers) } assert.EqualValues(t, len(countsByPeer), numAffinity+numAffinityWithNoCall, "Number of affinities nodes mismatch.") sort.Ints(counts) median := counts[len(counts)/2] testDistribution(t, countsByPeer, float64(median)*0.9, float64(median)*1.1) } func TestPeerSelectionAfterClosed(t *testing.T) { pt := &peerSelectionTest{ peerTest: peerTest{t: t}, numPeers: 5, numAffinity: 5, } defer pt.CleanUp() pt.setup(t) toClose := pt.affinity[pt.numAffinity-1] closedHP := toClose.PeerInfo().HostPort toClose.Logger().Debugf("About to Close %v", closedHP) waitTillInboundEmpty(t, pt.client, closedHP, func() { toClose.Close() }) for i := 0; i < 10*pt.numAffinity; i++ { peer, err := pt.client.Peers().Get(nil) assert.NoError(t, err, "Client failed to select a peer") assert.NotEqual(pt.t, closedHP, peer.HostPort(), "Closed peer shouldn't be chosen") } } func TestPeerScoreOnNewConnection(t *testing.T) { tests := []struct { message string connect func(s1, s2 *Channel) *Peer }{ { message: "outbound connection", connect: func(s1, s2 *Channel) *Peer { return s1.Peers().GetOrAdd(s2.PeerInfo().HostPort) }, }, { message: "inbound connection", connect: func(s1, s2 *Channel) *Peer { return s2.Peers().GetOrAdd(s1.PeerInfo().HostPort) }, }, } 
getScore := func(pl *PeerList) uint64 { peers := pl.IntrospectList(nil) require.Equal(t, 1, len(peers), "Wrong number of peers") return peers[0].Score } for _, tt := range tests { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(time.Second) defer cancel() s1 := ts.Server() s2 := ts.NewServer(nil) s1.Peers().Add(s2.PeerInfo().HostPort) s2.Peers().Add(s1.PeerInfo().HostPort) initialScore := getScore(s1.Peers()) peer := tt.connect(s1, s2) conn, err := peer.GetConnection(ctx) require.NoError(t, err, "%v: GetConnection failed", tt.message) // When receiving an inbound connection, the outbound connect may return // before the inbound has updated the score, so we may need to retry. assert.True(t, testutils.WaitFor(time.Second, func() bool { connectedScore := getScore(s1.Peers()) return connectedScore < initialScore }), "%v: Expected connected peer score %v to be less than initial score %v", tt.message, getScore(s1.Peers()), initialScore) // Ping to ensure the connection has been added to peers on both sides. require.NoError(t, conn.Ping(ctx), "%v: Ping failed", tt.message) }) } } func TestConnectToPeerHostPortMismatch(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(time.Second) defer cancel() // Set up a relay which will have a different host:port than the // real TChannel HostPort. relay, err := benchmark.NewTCPRawRelay([]string{ts.HostPort()}) require.NoError(t, err, "Failed to set up TCP relay") defer relay.Close() s2 := ts.NewServer(nil) for i := 0; i < 10; i++ { require.NoError(t, s2.Ping(ctx, relay.HostPort()), "Ping failed") } assert.Equal(t, 1, s2.IntrospectNumConnections(), "Unexpected number of connections") }) } // Test ensures that a closing connection does not count in NumConnections. // NumConnections should only include connections that be used to make calls. 
func TestPeerConnectionsClosing(t *testing.T) { // Disable the relay since we check the host:port directly. opts := testutils.NewOpts().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { unblock := make(chan struct{}) gotCall := make(chan struct{}) testutils.RegisterEcho(ts.Server(), func() { close(gotCall) <-unblock }) client := ts.NewServer(nil) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) }() // Wait for the call to be received before checking connections.. <-gotCall peer := ts.Server().Peers().GetOrAdd(client.PeerInfo().HostPort) in, out := peer.NumConnections() assert.Equal(t, 1, in+out, "Unexpected number of incoming connections") // Now when we try to close the channel, all the connections will change // state, and should no longer count as active connections. conn, err := peer.GetConnection(nil) require.NoError(t, err, "Failed to get connection") require.True(t, conn.IsActive(), "Connection should be active") ts.Server().Close() require.False(t, conn.IsActive(), "Connection should not be active after Close") in, out = peer.NumConnections() assert.Equal(t, 0, in+out, "Inactive connections should not be included in peer LAST") close(unblock) wg.Wait() }) } func BenchmarkAddPeers(b *testing.B) { for i := 0; i < b.N; i++ { ch := testutils.NewClient(b, nil) for i := 0; i < 1000; i++ { hp := fmt.Sprintf("127.0.0.1:%v", i) ch.Peers().Add(hp) } } } func TestPeerSelectionStrategyChange(t *testing.T) { const numPeers = 2 ch := testutils.NewClient(t, nil) defer ch.Close() for i := 0; i < numPeers; i++ { ch.Peers().Add(fmt.Sprintf("127.0.0.1:60%v", i)) } for _, score := range []uint64{1000, 2000} { ch.Peers().SetStrategy(createConstScoreStrategy(score)) for _, v := range ch.Peers().IntrospectList(nil) { assert.Equal(t, v.Score, score) } } } func createConstScoreStrategy(score uint64) (calc ScoreCalculator) { return ScoreCalculatorFunc(func(p *Peer) 
uint64 { return score }) } ================================================ FILE: peers/doc.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. /* Package peers provides helpers for managing TChannel peers. */ package peers ================================================ FILE: peers/prefer.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package peers import "github.com/uber/tchannel-go" type hrwScoreCalc struct { clientID uint32 } // NewHRWScorer returns a ScoreCalculator based on Rendezvous or Highest Random Weight // hashing. // It is useful for distributing load in peer-to-peer situations where we have // many clients picking from a set of servers with "sticky" semantics that will // spread load evenly as servers go down or new servers are added. // The clientID is used to score the servers, so each client should pass in // a unique client ID. 
func NewHRWScorer(clientID uint32) tchannel.ScoreCalculator { return &hrwScoreCalc{mod2_31(clientID)} } func (s *hrwScoreCalc) GetScore(p *tchannel.Peer) uint64 { server := mod2_31(fnv32a(p.HostPort())) // These constants are taken from W_rand2 in the Rendezvous paper: // http://www.eecs.umich.edu/techreports/cse/96/CSE-TR-316-96.pdf v := 1103515245*((1103515245*s.clientID+12345)^server) + 12345 return uint64(mod2_31(v)) } func mod2_31(v uint32) uint32 { return v & ((1 << 31) - 1) } // This is based on the standard library's fnv32a implementation. // We copy it for a couple of reasons: // 1. Avoid allocations to create a hash.Hash32 // 2. Avoid converting the []byte to a string (another allocation) since // the Hash32 interface only supports writing bytes. func fnv32a(s string) uint32 { const ( initial = 2166136261 prime = 16777619 ) hash := uint32(initial) for i := 0; i < len(s); i++ { hash ^= uint32(s[i]) hash *= prime } return hash } ================================================ FILE: peers/prefer_test.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package peers import ( "fmt" "hash/fnv" "testing" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "golang.org/x/net/context" ) func TestHRWScorerGetScore(t *testing.T) { client := testutils.NewClient(t, nil) peer := client.Peers().GetOrAdd("1.1.1.1") c1 := NewHRWScorer(1) c2 := NewHRWScorer(2) assert.NotEqual(t, c1.GetScore(peer), c2.GetScore(peer)) } func TestHRWScorerDistribution(t *testing.T) { const ( numClients = 1000 numServers = 10 ) ch := testutils.NewClient(t, nil) servers := make([]*tchannel.Peer, numServers) for i := range servers { servers[i] = ch.Peers().GetOrAdd(fmt.Sprintf("192.0.2.%v", i)) } serverSelected := make([]int, numServers) for i := 0; i < numClients; i++ { client := NewHRWScorer(uint32(i)) highestScore := uint64(0) highestServer := -1 for s, server := range servers { if score := client.GetScore(server); score > highestScore { highestScore = score highestServer = s } } serverSelected[highestServer]++ } // We can't get a perfect distribution, but should be within 20%. const ( expectCalls = numClients / numServers delta = expectCalls * 0.2 ) for serverIdx, count := range serverSelected { assert.InDelta(t, expectCalls, count, delta, "Server %v out of range", serverIdx) } } func countingServer(t *testing.T, opts *testutils.ChannelOpts) (*tchannel.Channel, *atomic.Int32) { var cnt atomic.Int32 server := testutils.NewServer(t, opts) testutils.RegisterEcho(server, func() { cnt.Inc() }) return server, &cnt } func TestHRWScorerIntegration(t *testing.T) { // Client pings to the server may cause errors during Close. 
sOpts := testutils.NewOpts().SetServiceName("svc").DisableLogVerification() s1, s1Count := countingServer(t, sOpts) s2, s2Count := countingServer(t, sOpts) client := testutils.NewClient(t, testutils.NewOpts().DisableLogVerification()) client.Peers().SetStrategy(NewHRWScorer(1)) client.Peers().Add(s1.PeerInfo().HostPort) client.Peers().Add(s2.PeerInfo().HostPort) // We want to call the raw echo function with TChannel retries. callEcho := func() error { ctx, cancel := tchannel.NewContext(time.Second) defer cancel() return client.RunWithRetry(ctx, func(ctx context.Context, rs *tchannel.RequestState) error { _, err := raw.CallV2(ctx, client.GetSubChannel("svc"), raw.CArgs{ Method: "echo", CallOptions: &tchannel.CallOptions{ RequestState: rs, }, }) return err }) } preferred, err := client.Peers().Get(nil) require.NoError(t, err, "Failed to get peer") if preferred.HostPort() == s2.PeerInfo().HostPort { // To make the test easier, we want "s1" to always be the preferred hostPort. s1, s1Count, s2, s2Count = s2, s2Count, s1, s1Count } // When we make 10 calls, all of them should go to s1 for i := 0; i < 10; i++ { err := callEcho() require.NoError(t, err, "Failed to call echo initially") } assert.EqualValues(t, 10, s1Count.Load(), "All calls should go to s1") // Stop s1, and ensure the client notices S1 has failed. s1.Close() testutils.WaitFor(time.Second, func() bool { if !s1.Closed() { return false } ctx, cancel := tchannel.NewContext(time.Second) defer cancel() return client.Ping(ctx, s1.PeerInfo().HostPort) != nil }) // Since s1 is stopped, next call should go to s2. err = callEcho() require.NoError(t, err, "Failed to call echo after s1 close") assert.EqualValues(t, 10, s1Count.Load(), "s1 should not get new calls as it's down") assert.EqualValues(t, 1, s2Count.Load(), "New call should go to s2") // And if s1 comes back, calls should resume to s1. 
s1Up := testutils.NewClient(t, sOpts) testutils.RegisterEcho(s1Up, func() { s1Count.Inc() }) err = s1Up.ListenAndServe(s1.PeerInfo().HostPort) require.NoError(t, err, "Failed to bring up a new channel as s1") for i := 0; i < 10; i++ { require.NoError(t, callEcho(), "Failed to call echo after s1 restarted") } assert.EqualValues(t, 20, s1Count.Load(), "Once s1 is up, calls should resume to s1") assert.EqualValues(t, 1, s2Count.Load(), "s2 should not receive calls after s1 restarted") } func stdFnv32a(s string) uint32 { h := fnv.New32a() h.Write([]byte(s)) return h.Sum32() } func TestFnv32a(t *testing.T) { tests := []string{ "", "1.1.1.1", "some-other-data", } for _, tt := range tests { assert.Equal(t, stdFnv32a(tt), fnv32a(tt), "Different results for %q", tt) } } func BenchmarkHrwScoreCalc(b *testing.B) { client := testutils.NewClient(b, nil) peer := client.Peers().GetOrAdd("1.1.1.1") c := NewHRWScorer(1) for i := 0; i < b.N; i++ { c.GetScore(peer) } } ================================================ FILE: pprof/pprof.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package pprof import ( "net/http" _ "net/http/pprof" // So pprof endpoints are registered on DefaultServeMux. "github.com/uber/tchannel-go" thttp "github.com/uber/tchannel-go/http" "golang.org/x/net/context" ) func serveHTTP(req *http.Request, response *tchannel.InboundCallResponse) { rw, finish := thttp.ResponseWriter(response) http.DefaultServeMux.ServeHTTP(rw, req) finish() } // Register registers pprof endpoints on the given registrar under _pprof. // The _pprof endpoint uses as-http and is a tunnel to the default serve mux. func Register(registrar tchannel.Registrar) { handler := func(ctx context.Context, call *tchannel.InboundCall) { req, err := thttp.ReadRequest(call) if err != nil { registrar.Logger().WithFields( tchannel.LogField{Key: "err", Value: err.Error()}, ).Warn("Failed to read HTTP request.") return } serveHTTP(req, call.Response()) } registrar.Register(tchannel.HandlerFunc(handler), "_pprof") } ================================================ FILE: pprof/pprof_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package pprof import ( "io/ioutil" "net/http" "testing" "time" "github.com/uber/tchannel-go" thttp "github.com/uber/tchannel-go/http" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestPProfEndpoint(t *testing.T) { ch := testutils.NewServer(t, nil) Register(ch) ctx, cancel := tchannel.NewContext(time.Second) defer cancel() req, err := http.NewRequest("GET", "/debug/pprof/block?debug=1", nil) require.NoError(t, err, "NewRequest failed") call, err := ch.BeginCall(ctx, ch.PeerInfo().HostPort, ch.ServiceName(), "_pprof", nil) require.NoError(t, err, "BeginCall failed") require.NoError(t, err, thttp.WriteRequest(call, req), "thttp.WriteRequest failed") response, err := thttp.ReadResponse(call.Response()) require.NoError(t, err, "ReadResponse failed") assert.Equal(t, http.StatusOK, response.StatusCode) body, err := ioutil.ReadAll(response.Body) if assert.NoError(t, err, "Read body failed") { assert.Contains(t, string(body), "contention", "Response does not contain expected string") } } ================================================ FILE: preinit_connection.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"net"
	"strconv"
	"time"

	"golang.org/x/net/context"
)

// outboundHandshake performs the client side of the init handshake on c: it
// sends an initReq with message ID 1, reads the initRes, validates its ID and
// protocol version, and wraps c in a Connection. Any failure is passed through
// initError, which reports it to the peer and closes c.
func (ch *Channel) outboundHandshake(ctx context.Context, c net.Conn, outboundHP string, events connectionEvents) (_ *Connection, err error) {
	// Deferred calls run last-in first-out: initError runs while the init
	// deadline is still set, and only then is the deadline cleared.
	defer setInitDeadline(ctx, c)()
	defer func() {
		err = ch.initError(c, outbound, 1, err)
	}()

	msg := &initReq{initMessage: ch.getInitMessage(ctx, 1)}
	if err := ch.writeMessage(c, msg); err != nil {
		return nil, err
	}

	res := &initRes{}
	id, err := ch.readMessage(c, res)
	if err != nil {
		return nil, err
	}

	switch {
	case id != msg.id:
		return nil, NewSystemError(ErrCodeProtocol, "received initRes with invalid ID, wanted %v, got %v", msg.id, id)
	case res.Version != CurrentProtocolVersion:
		return nil, unsupportedProtocolVersion(res.Version)
	}

	remotePeer, remotePeerAddress, err := parseRemotePeer(res.initParams, c.RemoteAddr())
	if err != nil {
		return nil, NewWrappedSystemError(ErrCodeProtocol, err)
	}

	// The connection uses the caller-supplied base context when one is set,
	// and a background context otherwise.
	baseCtx := context.Background()
	if params := getTChannelParams(ctx); params != nil && params.connectBaseContext != nil {
		baseCtx = params.connectBaseContext
	}

	return ch.newConnection(baseCtx, c, 1 /* initialID */, outboundHP, remotePeer, remotePeerAddress, events), nil
}

// inboundHandshake performs the server side of the init handshake on c: it
// reads the initReq, rejects protocol versions older than the current one,
// replies with an initRes carrying the request's ID, and wraps c in a
// Connection. Any failure is passed through initError.
func (ch *Channel) inboundHandshake(ctx context.Context, c net.Conn, events connectionEvents) (_ *Connection, err error) {
	// Until a request has been read, errors are reported with the maximum ID.
	id := uint32(math.MaxUint32)

	// Same ordering as the outbound side: initError first, then the deadline
	// is cleared.
	defer setInitDeadline(ctx, c)()
	defer func() {
		err = ch.initError(c, inbound, id, err)
	}()

	req := &initReq{}
	id, err = ch.readMessage(c, req)
	if err != nil {
		return nil, err
	}

	if req.Version < CurrentProtocolVersion {
		return nil, unsupportedProtocolVersion(req.Version)
	}

	remotePeer, remotePeerAddress, err := parseRemotePeer(req.initParams, c.RemoteAddr())
	if err != nil {
		return nil, NewWrappedSystemError(ErrCodeProtocol, err)
	}

	res := &initRes{initMessage: ch.getInitMessage(ctx, id)}
	if err := ch.writeMessage(c, res); err != nil {
		return nil, err
	}

	return ch.newConnection(ctx, c, 0 /* initialID */, "" /* outboundHP */, remotePeer, remotePeerAddress, events), nil
}

// getInitParams builds the init headers describing the local peer.
func (ch *Channel) getInitParams() initParams {
	local := ch.PeerInfo()
	version := local.Version
	return initParams{
		InitParamHostPort:                local.HostPort,
		InitParamProcessName:             local.ProcessName,
		InitParamTChannelLanguage:        version.Language,
		InitParamTChannelLanguageVersion: version.LanguageVersion,
		InitParamTChannelVersion:         version.TChannelVersion,
	}
}

// getInitMessage builds the init message with the given ID. When the context
// asks to hide the listening address on outbound connections, the advertised
// host:port is replaced with the ephemeral host:port.
func (ch *Channel) getInitMessage(ctx context.Context, id uint32) initMessage {
	msg := initMessage{
		id:         id,
		Version:    CurrentProtocolVersion,
		initParams: ch.getInitParams(),
	}
	params := getTChannelParams(ctx)
	if params != nil && params.hideListeningOnOutbound {
		msg.initParams[InitParamHostPort] = ephemeralHostPort
	}

	return msg
}

// initError is the common handshake failure path. A nil err is returned
// unchanged. Otherwise the failure is logged, timeouts are mapped to
// ErrTimeout and io.EOF to a network system error, an error message is
// written to the peer, the connection is closed, and the mapped error is
// returned.
func (ch *Channel) initError(c net.Conn, connDir connectionDirection, id uint32, err error) error {
	if err == nil {
		return nil
	}

	fields := LogFields{
		{"connectionDirection", connDir},
		{"localAddr", c.LocalAddr().String()},
		{"remoteAddr", c.RemoteAddr().String()},
		ErrField(err),
	}
	ch.log.WithFields(fields...).Error("Failed during connection handshake.")

	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		err = ErrTimeout
	}
	if err == io.EOF {
		err = NewWrappedSystemError(ErrCodeNetwork, io.EOF)
	}

	// The result of this write is not checked; the connection is closed
	// immediately afterwards either way.
	ch.writeMessage(c, &errorMessage{
		id:      id,
		errCode: GetSystemErrorCode(err),
		message: err.Error(),
	})
	c.Close()
	return err
}

// writeMessage serializes msg into a pooled frame and writes the frame to c.
func (ch *Channel) writeMessage(c net.Conn, msg message) error {
	pool := ch.connectionOptions.FramePool
	frame := pool.Get()
	defer pool.Release(frame)

	if err := frame.write(msg); err != nil {
		return err
	}
	return frame.WriteOut(c)
}

// readMessage reads one frame from c and decodes it into msg, returning the
// frame's ID. If the frame is an error message instead of the expected type,
// the decoded error is returned; any other unexpected type is a protocol
// error.
func (ch *Channel) readMessage(c net.Conn, msg message) (uint32, error) {
	pool := ch.connectionOptions.FramePool
	frame := pool.Get()
	defer pool.Release(frame)

	if err := frame.ReadIn(c); err != nil {
		return 0, err
	}

	// Cases are tried in order, so a frame of the expected type is always
	// decoded into msg, even if that type is the error type.
	switch frame.Header.messageType {
	case msg.messageType():
		return frame.Header.ID, frame.read(msg)
	case messageTypeError:
		return frame.Header.ID, readError(frame)
	default:
		return frame.Header.ID, NewSystemError(ErrCodeProtocol, "expected message type %v, got %v", msg.messageType(), frame.Header.messageType)
	}
}

// parseRemotePeer extracts the remote peer's identity from its init headers.
// The host:port and process name headers are required. The host part of the
// address is additionally classified as IPv4, IPv6 or a hostname.
func parseRemotePeer(p initParams, remoteAddr net.Addr) (PeerInfo, peerAddressComponents, error) {
	var (
		remotePeer        PeerInfo
		remotePeerAddress peerAddressComponents
		ok                bool
	)

	if remotePeer.HostPort, ok = p[InitParamHostPort]; !ok {
		return remotePeer, remotePeerAddress, fmt.Errorf("header %v is required", InitParamHostPort)
	}
	if remotePeer.ProcessName, ok = p[InitParamProcessName]; !ok {
		return remotePeer, remotePeerAddress, fmt.Errorf("header %v is required", InitParamProcessName)
	}

	// If the remote host:port is ephemeral, use the socket address as the
	// host:port and set IsEphemeral to true.
	if isEphemeralHostPort(remotePeer.HostPort) {
		remotePeer.HostPort = remoteAddr.String()
		remotePeer.IsEphemeral = true
	}

	remotePeer.Version.Language = p[InitParamTChannelLanguage]
	remotePeer.Version.LanguageVersion = p[InitParamTChannelLanguageVersion]
	remotePeer.Version.TChannelVersion = p[InitParamTChannelVersion]

	// When the value does not split as host:port, all of it is treated as the
	// host and the port stays zero.
	host := remotePeer.HostPort
	if splitHost, splitPort, err := net.SplitHostPort(host); err == nil {
		host = splitHost
		if port, err := strconv.ParseUint(splitPort, 10, 16); err == nil {
			remotePeerAddress.port = uint16(port)
		}
	}

	if host == "localhost" {
		// 127.0.0.1 as a big-endian uint32.
		remotePeerAddress.ipv4 = 127<<24 | 1
		return remotePeer, remotePeerAddress, nil
	}

	ip := net.ParseIP(host)
	if ip == nil {
		remotePeerAddress.hostname = host
		return remotePeer, remotePeerAddress, nil
	}

	if ip4 := ip.To4(); ip4 != nil {
		remotePeerAddress.ipv4 = binary.BigEndian.Uint32(ip4)
	} else {
		remotePeerAddress.ipv6 = host
	}
	return remotePeer, remotePeerAddress, nil
}

// setInitDeadline sets the deadline on c to the context's deadline, or to five
// seconds from now when the context has none. The returned function clears
// the deadline again.
func setInitDeadline(ctx context.Context, c net.Conn) func() {
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		deadline = time.Now().Add(5 * time.Second)
	}

	c.SetDeadline(deadline)
	return func() {
		c.SetDeadline(time.Time{})
	}
}

// readError decodes an error message from the frame and returns it as a
// system error, or returns the decoding error if the frame cannot be read.
func readError(frame *Frame) error {
	errMsg := &errorMessage{id: frame.Header.ID}
	if err := frame.read(errMsg); err != nil {
		return err
	}

	return errMsg.AsSystemError()
}

// unsupportedProtocolVersion builds the protocol error returned when the peer
// uses a protocol version we do not accept.
func unsupportedProtocolVersion(got uint16) error {
	return NewSystemError(ErrCodeProtocol, "unsupported protocol version %d from peer, expected %v", got, CurrentProtocolVersion)
}

================================================
FILE: raw/call.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package raw

import (
	"errors"

	"golang.org/x/net/context"

	"github.com/uber/tchannel-go"
)

// ErrAppError is returned if the application sets an error response.
var ErrAppError = errors.New("application error")

// ReadArgsV2 reads arg2 and arg3 from a reader.
func ReadArgsV2(r tchannel.ArgReadable) ([]byte, []byte, error) { var arg2, arg3 []byte if err := tchannel.NewArgReader(r.Arg2Reader()).Read(&arg2); err != nil { return nil, nil, err } if err := tchannel.NewArgReader(r.Arg3Reader()).Read(&arg3); err != nil { return nil, nil, err } return arg2, arg3, nil } // WriteArgs writes the given arguments to the call, and returns the response args. func WriteArgs(call *tchannel.OutboundCall, arg2, arg3 []byte) ([]byte, []byte, *tchannel.OutboundCallResponse, error) { if err := tchannel.NewArgWriter(call.Arg2Writer()).Write(arg2); err != nil { return nil, nil, nil, err } if err := tchannel.NewArgWriter(call.Arg3Writer()).Write(arg3); err != nil { return nil, nil, nil, err } resp := call.Response() var respArg2 []byte if err := tchannel.NewArgReader(resp.Arg2Reader()).Read(&respArg2); err != nil { return nil, nil, nil, err } var respArg3 []byte if err := tchannel.NewArgReader(resp.Arg3Reader()).Read(&respArg3); err != nil { return nil, nil, nil, err } return respArg2, respArg3, resp, nil } // Call makes a call to the given hostPort with the given arguments and returns the response args. func Call(ctx context.Context, ch *tchannel.Channel, hostPort string, serviceName, method string, arg2, arg3 []byte) ([]byte, []byte, *tchannel.OutboundCallResponse, error) { call, err := ch.BeginCall(ctx, hostPort, serviceName, method, nil) if err != nil { return nil, nil, nil, err } return WriteArgs(call, arg2, arg3) } // CallSC makes a call using the given subcahnnel func CallSC(ctx context.Context, sc *tchannel.SubChannel, method string, arg2, arg3 []byte) ( []byte, []byte, *tchannel.OutboundCallResponse, error) { call, err := sc.BeginCall(ctx, method, nil) if err != nil { return nil, nil, nil, err } return WriteArgs(call, arg2, arg3) } // CArgs are the call arguments passed to CallV2. type CArgs struct { Method string Arg2 []byte Arg3 []byte CallOptions *tchannel.CallOptions } // CRes is the result of making a call. 
type CRes struct { Arg2 []byte Arg3 []byte AppError bool } // CallV2 makes a call and does not attempt any retries. func CallV2(ctx context.Context, sc *tchannel.SubChannel, cArgs CArgs) (*CRes, error) { call, err := sc.BeginCall(ctx, cArgs.Method, cArgs.CallOptions) if err != nil { return nil, err } arg2, arg3, res, err := WriteArgs(call, cArgs.Arg2, cArgs.Arg3) if err != nil { return nil, err } return &CRes{ Arg2: arg2, Arg3: arg3, AppError: res.ApplicationError(), }, nil } ================================================ FILE: raw/handler.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package raw import ( "golang.org/x/net/context" "github.com/uber/tchannel-go" ) // Handler is the interface for a raw handler. type Handler interface { // Handle is called on incoming calls, and contains all the arguments. 
// If an error is returned, it will set ApplicationError Arg3 will be the error string. Handle(ctx context.Context, args *Args) (*Res, error) OnError(ctx context.Context, err error) } // Args parses the arguments from an incoming call req. type Args struct { Caller string Format tchannel.Format Method string Arg2 []byte Arg3 []byte } // Res represents the response to an incoming call req. type Res struct { SystemErr error // IsErr is used to set an application error on the underlying call res. IsErr bool Arg2 []byte Arg3 []byte } // ReadArgs reads the *Args from the given call. func ReadArgs(call *tchannel.InboundCall) (*Args, error) { var args Args args.Caller = call.CallerName() args.Format = call.Format() args.Method = string(call.Method()) if err := tchannel.NewArgReader(call.Arg2Reader()).Read(&args.Arg2); err != nil { return nil, err } if err := tchannel.NewArgReader(call.Arg3Reader()).Read(&args.Arg3); err != nil { return nil, err } return &args, nil } // WriteResponse writes the given Res to the InboundCallResponse. func WriteResponse(response *tchannel.InboundCallResponse, resp *Res) error { if resp.SystemErr != nil { return response.SendSystemError(resp.SystemErr) } if resp.IsErr { if err := response.SetApplicationError(); err != nil { return err } } if err := tchannel.NewArgWriter(response.Arg2Writer()).Write(resp.Arg2); err != nil { return err } return tchannel.NewArgWriter(response.Arg3Writer()).Write(resp.Arg3) } // Wrap wraps a Handler as a tchannel.Handler that can be passed to tchannel.Register. 
func Wrap(handler Handler) tchannel.Handler { return tchannel.HandlerFunc(func(ctx context.Context, call *tchannel.InboundCall) { args, err := ReadArgs(call) if err != nil { handler.OnError(ctx, err) return } resp, err := handler.Handle(ctx, args) response := call.Response() if err != nil { resp = &Res{ SystemErr: err, } } if err := WriteResponse(response, resp); err != nil { handler.OnError(ctx, err) } }) } ================================================ FILE: relay/relay.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // Package relay contains relaying interfaces for external use. // // These interfaces are currently unstable, and aren't covered by the API // backwards-compatibility guarantee. 
package relay

import (
	"context"
	"time"

	"github.com/uber/tchannel-go/thrift/arg2"
)

// KeyVal is a key/value pair in arg2.
type KeyVal struct {
	Key []byte
	Val []byte
}

// CallFrame is an interface that abstracts access to the call req frame.
type CallFrame interface {
	// TTL is the TTL of the underlying frame.
	TTL() time.Duration
	// Caller is the name of the originating service.
	Caller() []byte
	// Service is the name of the destination service.
	Service() []byte
	// Method is the name of the method being called.
	Method() []byte
	// RoutingDelegate is the name of the routing delegate, if any.
	RoutingDelegate() []byte
	// RoutingKey may refer to an alternate traffic group instead of the
	// traffic group identified by the service name.
	RoutingKey() []byte
	// Arg2StartOffset returns the offset from start of payload to the
	// beginning of Arg2 in bytes.
	Arg2StartOffset() int
	// Arg2EndOffset returns the offset from start of payload to the end of
	// Arg2 in bytes, and hasMore to indicate if there are more frames and
	// Arg3 has not started (i.e. Arg2 is fragmented).
	Arg2EndOffset() (_ int, hasMore bool)
	// Arg2Iterator returns the iterator for reading Arg2 key value pair
	// of TChannel-Thrift Arg Scheme. If no iterator is available, return
	// io.EOF.
	Arg2Iterator() (arg2.KeyValIterator, error)
	// Arg2Append appends a key/val pair to arg2.
	Arg2Append(key, val []byte)
}

// RespFrame is an interface that abstracts access to the CallRes frame.
type RespFrame interface {
	// OK indicates whether the call was successful.
	OK() bool
	// ArgScheme returns the scheme of the arg.
	ArgScheme() []byte
	// Arg2IsFragmented indicates whether arg2 runs over the first frame.
	Arg2IsFragmented() bool
	// Arg2 returns the raw arg2 payload.
	Arg2() []byte
}

// Conn contains information about the underlying connection.
type Conn struct {
	// RemoteAddr is the remote address of the underlying TCP connection.
	RemoteAddr string
	// RemoteProcessName is the process name sent in the TChannel handshake.
	RemoteProcessName string
	// IsOutbound reports whether this connection is an outbound connection
	// initiated via the relay.
	IsOutbound bool
	// Context contains connection-specific context which can be accessed via
	// RelayHost.Start()
	Context context.Context
}

// RateLimitDropError is the error that should be returned from
// RelayHost.Start if the request should be dropped silently.
// This is a bit of a hack, because rate limiting of this nature isn't part of
// the actual TChannel protocol.
// The relayer will record that it has dropped the packet, but *won't* notify
// the client.
type RateLimitDropError struct{}

// Error implements the error interface with a fixed message.
func (e RateLimitDropError) Error() string {
	return "frame dropped silently due to rate limiting"
}

================================================
FILE: relay/relaytest/func_host.go
================================================
package relaytest

import (
	"github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/relay"
)

// Compile-time checks that hostFunc implements tchannel.RelayHost and
// hostFuncPeer implements tchannel.RelayCall.
var (
	_ tchannel.RelayHost = (*hostFunc)(nil)
	_ tchannel.RelayCall = (*hostFuncPeer)(nil)
)

// hostFunc is a RelayHost whose peer selection is delegated to fn.
type hostFunc struct {
	ch    *tchannel.Channel
	stats *MockStats
	fn    func(relay.CallFrame, *relay.Conn) (string, error)
}

// hostFuncPeer is the RelayCall returned by hostFunc.Start. It records the
// selected peer (possibly nil) and the last response frame it was given.
type hostFuncPeer struct {
	*MockCallStats

	peer      *tchannel.Peer
	respFrame relay.RespFrame
}

// HostFunc wraps a given function to implement tchannel.RelayHost.
func HostFunc(fn func(relay.CallFrame, *relay.Conn) (string, error)) tchannel.RelayHost {
	host := &hostFunc{}
	host.fn = fn
	return host
}

// SetChannel stores the channel used for peer lookups and installs a fresh
// MockStats.
func (hf *hostFunc) SetChannel(ch *tchannel.Channel) {
	hf.ch, hf.stats = ch, NewMockStats()
}

// Start asks the wrapped function for a host:port. A non-empty host:port is
// resolved (or added) in the peer list of the destination service's
// subchannel. The function's error is returned alongside the call.
func (hf *hostFunc) Start(cf relay.CallFrame, conn *relay.Conn) (tchannel.RelayCall, error) {
	hostPort, err := hf.fn(cf, conn)

	var dest *tchannel.Peer
	if hostPort != "" {
		dest = hf.ch.GetSubChannel(string(cf.Service())).Peers().GetOrAdd(hostPort)
	}

	// We still track stats if we failed to get a peer, so return the peer.
	call := &hostFuncPeer{
		MockCallStats: hf.stats.Begin(cf),
		peer:          dest,
	}
	return call, err
}

// Stats returns the MockStats installed by SetChannel.
func (hf *hostFunc) Stats() *MockStats {
	return hf.stats
}

// Destination returns the selected peer and whether one was selected.
func (p *hostFuncPeer) Destination() (*tchannel.Peer, bool) {
	hasPeer := p.peer != nil
	return p.peer, hasPeer
}

// CallResponse remembers the response frame on the call.
func (p *hostFuncPeer) CallResponse(frame relay.RespFrame) {
	p.respFrame = frame
}

================================================
FILE: relay/relaytest/mock_stats.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package relaytest

import (
	"fmt"
	"sort"
	"strings"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/uber/tchannel-go/relay"
)

// MockCallStats is a testing spy for the CallStats interface.
type MockCallStats struct {
	// Store ints and slices instead of bools and strings so that we can assert
	// the actual sequence of calls (in case we expect to call both Succeeded
	// and Failed). The real implementation will have the first writer win.
	succeeded  int
	failedMsgs []string
	ended      int
	sent       int
	received   int
	wg         *sync.WaitGroup
}

// Succeeded marks the RPC as succeeded.
func (m *MockCallStats) Succeeded() {
	m.succeeded++
}

// Failed marks the RPC as failed for the provided reason.
func (m *MockCallStats) Failed(reason string) {
	m.failedMsgs = append(m.failedMsgs, reason)
}

// SentBytes tracks the sent bytes.
func (m *MockCallStats) SentBytes(size uint16) {
	m.sent += int(size)
}

// ReceivedBytes tracks the received bytes.
func (m *MockCallStats) ReceivedBytes(size uint16) {
	m.received += int(size)
}

// End halts timer and metric collection for the RPC, and marks the call as
// done on the wait group it was created with.
func (m *MockCallStats) End() {
	m.ended++
	m.wg.Done()
}

// FluentMockCallStats wraps the MockCallStats in a fluent API that's convenient for tests.
type FluentMockCallStats struct {
	*MockCallStats
}

// Succeeded marks the RPC as succeeded.
func (f *FluentMockCallStats) Succeeded() *FluentMockCallStats {
	f.MockCallStats.Succeeded()
	return f
}

// Failed marks the RPC as failed.
func (f *FluentMockCallStats) Failed(reason string) *FluentMockCallStats {
	f.MockCallStats.Failed(reason)
	return f
}

// MockStats is a testing spy for the Stats interface.
type MockStats struct {
	mu    sync.Mutex
	wg    sync.WaitGroup
	stats map[string][]*MockCallStats
}

// NewMockStats constructs a MockStats.
func NewMockStats() *MockStats {
	callsByEdge := make(map[string][]*MockCallStats)
	return &MockStats{stats: callsByEdge}
}

// Begin starts collecting metrics for an RPC, keyed by the frame's caller,
// service and method.
func (m *MockStats) Begin(f relay.CallFrame) *MockCallStats {
	caller, callee, procedure := string(f.Caller()), string(f.Service()), string(f.Method())
	return m.Add(caller, callee, procedure).MockCallStats
}

// Add explicitly adds a new call along an edge of the call graph.
func (m *MockStats) Add(caller, callee, procedure string) *FluentMockCallStats { m.wg.Add(1) cs := &MockCallStats{wg: &m.wg} key := m.tripleToKey(caller, callee, procedure) m.mu.Lock() m.stats[key] = append(m.stats[key], cs) m.mu.Unlock() return &FluentMockCallStats{cs} } // AssertEqual asserts that two MockStats describe the same call graph. func (m *MockStats) AssertEqual(t testing.TB, expected *MockStats) { m.WaitForEnd() m.mu.Lock() defer m.mu.Unlock() expected.mu.Lock() defer expected.mu.Unlock() if assert.Equal(t, getEdges(expected.stats), getEdges(m.stats), "Found calls along unexpected edges.") { for edge := range expected.stats { m.assertEdgeEqual(t, expected, edge) } } } // WaitForEnd waits for all calls to End. func (m *MockStats) WaitForEnd() { m.wg.Wait() } func (m *MockStats) assertEdgeEqual(t testing.TB, expected *MockStats, edge string) { expectedCalls := expected.stats[edge] actualCalls := m.stats[edge] if assert.Equal(t, len(expectedCalls), len(actualCalls), "Unexpected number of calls along %s edge.", edge) { for i := range expectedCalls { m.assertCallEqual(t, expectedCalls[i], actualCalls[i]) } } } func (m *MockStats) assertCallEqual(t testing.TB, expected *MockCallStats, actual *MockCallStats) { // Revisit these assertions if we ever need to assert zero or many calls to // End. require.Equal(t, 1, expected.ended, "Expected call must assert exactly one call to End.") require.False( t, expected.succeeded <= 0 && len(expected.failedMsgs) == 0, "Expectation must indicate whether RPC should succeed or fail.", ) failed := !assert.Equal(t, expected.succeeded, actual.succeeded, "Unexpected number of successes.") failed = !assert.Equal(t, expected.failedMsgs, actual.failedMsgs, "Unexpected reasons for RPC failure.") || failed failed = !assert.Equal(t, expected.ended, actual.ended, "Unexpected number of calls to End.") || failed if failed { // The default testify output is often insufficient. 
t.Logf("\nExpected relayed stats were:\n\t%+v\nActual relayed stats were:\n\t%+v\n", expected, actual) } } func (m *MockStats) tripleToKey(caller, callee, procedure string) string { return fmt.Sprintf("%s->%s::%s", caller, callee, procedure) } func getEdges(m map[string][]*MockCallStats) []string { edges := make([]string, 0, len(m)) for k := range m { edges = append(edges, k) } sort.Strings(edges) return edges } // Map returns all stats as a map of key to int. // It waits for any ongoing calls to end first to avoid races. func (m *MockStats) Map() map[string]int { m.WaitForEnd() m.mu.Lock() defer m.mu.Unlock() stats := make(map[string]int) for k, calls := range m.stats { for _, call := range calls { name := k stats[name+".calls"]++ if call.ended > 0 { stats[name+".ended"]++ } if call.succeeded > 0 { stats[name+".succeeded"]++ } if len(call.failedMsgs) > 0 { failureName := name + ".failed-" + strings.Join(call.failedMsgs, ",") stats[failureName]++ } stats[name+".sent-bytes"] = call.sent stats[name+".received-bytes"] = call.received } } return stats } ================================================ FILE: relay/relaytest/stub_host.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package relaytest import ( "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/relay" ) // Ensure that the StubRelayHost implements tchannel.RelayHost and stubCall implements // tchannel.RelayCall var _ tchannel.RelayHost = (*StubRelayHost)(nil) var _ tchannel.RelayCall = (*stubCall)(nil) // StubRelayHost is a stub RelayHost for tests that backs peer selection to an // underlying channel using isolated subchannels and the default peer selection. type StubRelayHost struct { ch *tchannel.Channel stats *MockStats frameFn func(relay.CallFrame, *relay.Conn) respFrameFn func(relay.RespFrame) } type stubCall struct { *MockCallStats peer *tchannel.Peer respFrameFn func(relay.RespFrame) } // NewStubRelayHost creates a new stub RelayHost for tests. func NewStubRelayHost() *StubRelayHost { return &StubRelayHost{ stats: NewMockStats(), respFrameFn: func(_ relay.RespFrame) {}, } } // SetFrameFn sets a function to run on every frame. func (rh *StubRelayHost) SetFrameFn(f func(relay.CallFrame, *relay.Conn)) { rh.frameFn = f } // SetRespFrameFn sets a function to run on every frame. func (rh *StubRelayHost) SetRespFrameFn(f func(relay.RespFrame)) { rh.respFrameFn = f } // SetChannel is called by the channel after creation so we can // get a reference to the channels' peers. func (rh *StubRelayHost) SetChannel(ch *tchannel.Channel) { rh.ch = ch } // Start starts a new RelayCall for the given call on a specific connection. 
func (rh *StubRelayHost) Start(cf relay.CallFrame, conn *relay.Conn) (tchannel.RelayCall, error) { if rh.frameFn != nil { rh.frameFn(cf, conn) } // Get a peer from the subchannel. peer, err := rh.ch.GetSubChannel(string(cf.Service())).Peers().Get(nil) return &stubCall{ MockCallStats: rh.stats.Begin(cf), peer: peer, respFrameFn: rh.respFrameFn, }, err } // Add adds a service instance with the specified host:port. func (rh *StubRelayHost) Add(service, hostPort string) { rh.ch.GetSubChannel(service, tchannel.Isolated).Peers().GetOrAdd(hostPort) } // Stats returns the *MockStats tracked for this channel. func (rh *StubRelayHost) Stats() *MockStats { return rh.stats } // Destination returns the selected peer for this call. func (c *stubCall) Destination() (*tchannel.Peer, bool) { return c.peer, c.peer != nil } func (c *stubCall) CallResponse(frame relay.RespFrame) { c.respFrameFn(frame) } ================================================ FILE: relay.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "bytes" "encoding/binary" "errors" "fmt" "io" "math" "sync" "time" "github.com/uber/tchannel-go/relay" "github.com/uber/tchannel-go/typed" "go.uber.org/atomic" ) const ( // _defaultRelayMaxTombs is the default maximum number of tombs we'll accumulate // in a single relayItems. _defaultRelayMaxTombs = 3e4 // _relayTombTTL is the length of time we'll keep a tomb before GC'ing it. _relayTombTTL = 3 * time.Second // _defaultRelayMaxTimeout is the default max TTL for relayed calls. _defaultRelayMaxTimeout = 2 * time.Minute ) // Error strings. const ( _relayErrorNotFound = "relay-not-found" _relayErrorDestConnSlow = "relay-dest-conn-slow" _relayErrorSourceConnSlow = "relay-source-conn-slow" _relayArg2ModifyFailed = "relay-arg2-modify-failed" // _relayNoRelease indicates that the relayed frame should not be released immediately, since // relayed frames normally end up in a send queue where it is released afterward. However in some // cases, such as frames that are fragmented due to being mutated, we need to release the original // frame as it won't be relayed. _relayNoRelease = false _relayShouldRelease = true ) // TODO: Replace errFrameNotSent with more specific errors from Receive. 
var ( errRelayMethodFragmented = NewSystemError(ErrCodeBadRequest, "relay handler cannot receive fragmented calls") errFrameNotSent = NewSystemError(ErrCodeNetwork, "frame was not sent to remote side") errBadRelayHost = NewSystemError(ErrCodeDeclined, "bad relay host implementation") errUnknownID = errors.New("non-callReq for inactive ID") errNoNHInArg2 = errors.New("no nh in arg2") errFragmentedArg2WithAppend = errors.New("fragmented arg2 not supported for appends") errArg2ThriftOnly = errors.New("cannot inspect or modify arg2 for non-Thrift calls") ) type relayItem struct { remapID uint32 tomb bool isOriginator bool call RelayCall destination *Relayer span Span timeout *relayTimer mutatedChecksum Checksum } type relayItems struct { sync.RWMutex logger Logger timeouts *relayTimerPool maxTombs uint64 tombs uint64 items map[uint32]relayItem } type frameReceiver interface { Receive(f *Frame, fType frameType) (sent bool, failureReason string) } func newRelayItems(logger Logger, maxTombs uint64) *relayItems { if maxTombs == 0 { maxTombs = _defaultRelayMaxTombs } return &relayItems{ items: make(map[uint32]relayItem), logger: logger, maxTombs: maxTombs, } } func (ri *relayItem) reportRelayBytes(fType frameType, frameSize uint16) { if fType == requestFrame { ri.call.SentBytes(frameSize) } else { ri.call.ReceivedBytes(frameSize) } } // Count returns the number of non-tombstone items in the relay. func (r *relayItems) Count() int { r.RLock() n := len(r.items) - int(r.tombs) r.RUnlock() return n } // Get checks for a relay item by ID, and will stop the timeout with the // read lock held (to avoid a race between timeout stop and deletion). // It returns whether a timeout was stopped, and if the item was found. 
func (r *relayItems) Get(id uint32, stopTimeout bool) (_ relayItem, stopped bool, found bool) { r.RLock() defer r.RUnlock() item, ok := r.items[id] if !ok { return relayItem{}, false /* stopped */, false /* found */ } if !stopTimeout { return item, false /* stopped */, true /* found */ } return item, item.timeout.Stop(), true /* found */ } // Add adds a relay item. func (r *relayItems) Add(id uint32, item relayItem) { r.Lock() r.items[id] = item r.Unlock() } // Delete removes a relayItem completely (without leaving a tombstone). It // returns the deleted item, along with a bool indicating whether we completed a // relayed call. func (r *relayItems) Delete(id uint32) (relayItem, bool) { r.Lock() item, ok := r.items[id] if !ok { r.Unlock() r.logger.WithFields(LogField{"id", id}).Warn("Attempted to delete non-existent relay item.") return item, false } delete(r.items, id) if item.tomb { r.tombs-- } r.Unlock() item.timeout.Release() return item, !item.tomb } // Entomb sets the tomb bit on a relayItem and schedules a garbage collection. It // returns the entombed item, along with a bool indicating whether we completed // a relayed call. func (r *relayItems) Entomb(id uint32, deleteAfter time.Duration) (relayItem, bool) { r.Lock() if r.tombs > r.maxTombs { r.Unlock() r.logger.WithFields(LogField{"id", id}).Warn("Too many tombstones, deleting relay item immediately.") return r.Delete(id) } item, ok := r.items[id] if !ok { r.Unlock() r.logger.WithFields(LogField{"id", id}).Warn("Can't find relay item to entomb.") return item, false } if item.tomb { r.Unlock() r.logger.WithFields(LogField{"id", id}).Warn("Re-entombing a tombstone.") return item, false } r.tombs++ item.tomb = true r.items[id] = item r.Unlock() // TODO: We should be clearing these out in batches, rather than creating // individual timers for each item. 
time.AfterFunc(deleteAfter, func() { r.Delete(id) }) return item, true } type frameType int const ( requestFrame frameType = 0 responseFrame frameType = 1 ) var _ frameReceiver = (*Relayer)(nil) // A Relayer forwards frames. type Relayer struct { relayHost RelayHost maxTimeout time.Duration maxConnTimeout time.Duration // localHandlers is the set of service names that are handled by the local // channel. localHandler map[string]struct{} // outbound is the remapping for requests that originated on this // connection, and are outbound towards some other connection. // It stores remappings for all request frames read on this connection. outbound *relayItems // inbound is the remapping for requests that originated on some other // connection which was directed to this connection. // It stores remappings for all response frames read on this connection. inbound *relayItems // timeouts is the pool of timers used to track call timeouts. // It allows timer re-use, while allowing timers to be created and started separately. timeouts *relayTimerPool peers *RootPeerList conn *Connection relayConn *relay.Conn logger Logger pending atomic.Uint32 } // NewRelayer constructs a Relayer. func NewRelayer(ch *Channel, conn *Connection) *Relayer { r := &Relayer{ relayHost: ch.RelayHost(), maxTimeout: ch.relayMaxTimeout, maxConnTimeout: ch.relayMaxConnTimeout, localHandler: ch.relayLocal, outbound: newRelayItems(conn.log.WithFields(LogField{"relayItems", "outbound"}), ch.relayMaxTombs), inbound: newRelayItems(conn.log.WithFields(LogField{"relayItems", "inbound"}), ch.relayMaxTombs), peers: ch.RootPeers(), conn: conn, relayConn: &relay.Conn{ RemoteAddr: conn.conn.RemoteAddr().String(), RemoteProcessName: conn.RemotePeerInfo().ProcessName, IsOutbound: conn.connDirection == outbound, Context: conn.baseContext, }, logger: conn.log, } r.timeouts = newRelayTimerPool(r.timeoutRelayItem, ch.relayTimerVerify) return r } // Relay is called for each frame that is read on the connection. 
func (r *Relayer) Relay(f *Frame) (shouldRelease bool, _ error) { if f.messageType() != messageTypeCallReq { shouldRelease, err := r.handleNonCallReq(f) if err == errUnknownID { // This ID may be owned by an outgoing call, so check the outbound // message exchange, and if it succeeds, then the frame has been // handled successfully. if err := r.conn.outbound.forwardPeerFrame(f); err == nil { return _relayNoRelease, nil } } return shouldRelease, err } cr, err := newLazyCallReq(f) if err != nil { return _relayNoRelease, err } return r.handleCallReq(cr) } // Receive receives frames intended for this connection. // It returns whether the frame was sent and a reason for failure if it failed. func (r *Relayer) Receive(f *Frame, fType frameType) (sent bool, failureReason string) { id := f.Header.ID // If we receive a response frame, we expect to find that ID in our outbound. // If we receive a request frame, we expect to find that ID in our inbound. items := r.receiverItems(fType) finished := finishesCall(f) // Stop the timeout if the call if finished. item, stopped, ok := items.Get(id, finished /* stopTimeout */) if !ok { r.logger.WithFields( LogField{"id", id}, ).Warn("Received a frame without a RelayItem.") return false, _relayErrorNotFound } if item.tomb || (finished && !stopped) { // Item has previously timed out, or is in the process of timing out. // TODO: metrics for late-arriving frames. return true, "" } // call res frames don't include the OK bit, so we can't wait until the last // frame of a relayed RPC to determine if the call succeeded. if fType == responseFrame || f.messageType() == messageTypeCancel { // If we've gotten a response frame, we're the originating relayer and // should handle stats. if succeeded, failMsg := determinesCallSuccess(f); succeeded { item.call.Succeeded() } else if len(failMsg) > 0 { item.call.Failed(failMsg) } } select { case r.conn.sendCh <- f: default: // Buffer is full, so drop this frame and cancel the call. 
// Since this is typically due to the send buffer being full, get send buffer // usage + limit and add that to the log. sendBuf, sendBufLimit, sendBufErr := r.conn.sendBufSize() now := r.conn.timeNow().UnixNano() logFields := []LogField{ {"id", id}, {"destConnSendBufferCurrent", sendBuf}, {"destConnSendBufferLimit", sendBufLimit}, {"sendChQueued", len(r.conn.sendCh)}, {"sendChCapacity", cap(r.conn.sendCh)}, {"lastActivityRead", r.conn.lastActivityRead.Load()}, {"lastActivityWrite", r.conn.lastActivityRead.Load()}, {"sinceLastActivityRead", time.Duration(now - r.conn.lastActivityRead.Load()).String()}, {"sinceLastActivityWrite", time.Duration(now - r.conn.lastActivityWrite.Load()).String()}, } if sendBufErr != nil { logFields = append(logFields, LogField{"destConnSendBufferError", sendBufErr.Error()}) } r.logger.WithFields(logFields...).Warn("Dropping call due to slow connection.") items := r.receiverItems(fType) err := _relayErrorDestConnSlow // If we're dealing with a response frame, then the client is slow. if fType == responseFrame { err = _relayErrorSourceConnSlow } r.failRelayItem(items, id, err, errFrameNotSent) return false, err } if finished { r.finishRelayItem(items, id) } return true, "" } func (r *Relayer) canHandleNewCall() (bool, connectionState) { var ( canHandle bool curState connectionState ) r.conn.withStateRLock(func() error { curState = r.conn.state canHandle = curState == connectionActive if canHandle { r.pending.Inc() } return nil }) return canHandle, curState } func (r *Relayer) getDestination(f *lazyCallReq, call RelayCall) (*Connection, bool, error) { if _, _, ok := r.outbound.Get(f.Header.ID, false /* stopTimeout */); ok { r.logger.WithFields( LogField{"id", f.Header.ID}, LogField{"source", string(f.Caller())}, LogField{"dest", string(f.Service())}, LogField{"method", string(f.Method())}, ).Warn("Received duplicate callReq.") call.Failed(ErrCodeProtocol.relayMetricsKey()) // TODO: this is a protocol error, kill the connection. 
return nil, false, errors.New("callReq with already active ID") } // Get the destination peer, ok := call.Destination() if !ok { call.Failed("relay-bad-relay-host") r.conn.SendSystemError(f.Header.ID, f.Span(), errBadRelayHost) return nil, false, errBadRelayHost } remoteConn, err := peer.getConnectionRelay(f.TTL(), r.maxConnTimeout) if err != nil { r.logger.WithFields( ErrField(err), LogField{"source", string(f.Caller())}, LogField{"dest", string(f.Service())}, LogField{"method", string(f.Method())}, LogField{"selectedPeer", peer}, ).Warn("Failed to connect to relay host.") call.Failed("relay-connection-failed") r.conn.SendSystemError(f.Header.ID, f.Span(), NewWrappedSystemError(ErrCodeNetwork, err)) return nil, false, nil } return remoteConn, true, nil } func (r *Relayer) handleCallReq(f *lazyCallReq) (shouldRelease bool, _ error) { if handled := r.handleLocalCallReq(f); handled { return _relayNoRelease, nil } call, err := r.relayHost.Start(f, r.relayConn) if err != nil { // If we have a RateLimitDropError we record the statistic, but // we *don't* send an error frame back to the client. if _, silentlyDrop := err.(relay.RateLimitDropError); silentlyDrop { if call != nil { call.Failed("relay-dropped") call.End() } return _relayShouldRelease, nil } if _, ok := err.(SystemError); !ok { err = NewSystemError(ErrCodeDeclined, err.Error()) } if call != nil { call.Failed(GetSystemErrorCode(err).relayMetricsKey()) call.End() } r.conn.SendSystemError(f.Header.ID, f.Span(), err) // If the RelayHost returns a protocol error, close the connection. if GetSystemErrorCode(err) == ErrCodeProtocol { return _relayShouldRelease, r.conn.close(LogField{"reason", "RelayHost returned protocol error"}) } return _relayShouldRelease, nil } // Check that the current connection is in a valid state to handle a new call. 
if canHandle, state := r.canHandleNewCall(); !canHandle { call.Failed("relay-client-conn-inactive") call.End() err := errConnNotActive{"incoming", state} r.conn.SendSystemError(f.Header.ID, f.Span(), NewWrappedSystemError(ErrCodeDeclined, err)) return _relayShouldRelease, err } // Get a remote connection and check whether it can handle this call. remoteConn, ok, err := r.getDestination(f, call) if err == nil && ok { if canHandle, state := remoteConn.relay.canHandleNewCall(); !canHandle { err = NewWrappedSystemError(ErrCodeNetwork, errConnNotActive{"selected remote", state}) call.Failed("relay-remote-inactive") r.conn.SendSystemError(f.Header.ID, f.Span(), NewWrappedSystemError(ErrCodeDeclined, err)) } } if err != nil || !ok { // Failed to get a remote connection, or the connection is not in the right // state to handle this call. Since we already incremented pending on // the current relay, we need to decrement it. r.decrementPending() call.End() return _relayShouldRelease, err } origID := f.Header.ID destinationID := remoteConn.NextMessageID() ttl := f.TTL() if ttl > r.maxTimeout { ttl = r.maxTimeout f.SetTTL(r.maxTimeout) } span := f.Span() var mutatedChecksum Checksum if len(f.arg2Appends) > 0 { mutatedChecksum = f.checksumType.New() } // The remote side of the relay doesn't need to track stats or call state. remoteConn.relay.addRelayItem(false /* isOriginator */, destinationID, f.Header.ID, r, ttl, span, call, nil /* mutatedChecksum */) relayToDest := r.addRelayItem(true /* isOriginator */, f.Header.ID, destinationID, remoteConn.relay, ttl, span, call, mutatedChecksum) f.Header.ID = destinationID // If we have appends, the size of the frame to be relayed will change, potentially going // over the max frame size. Do a fragmenting send which is slightly more expensive but // will handle fragmenting if it is needed. 
if len(f.arg2Appends) > 0 { if err := r.fragmentingSend(call, f, relayToDest, origID); err != nil { r.failRelayItem(r.outbound, origID, _relayArg2ModifyFailed, err) r.logger.WithFields( LogField{"id", origID}, LogField{"err", err.Error()}, LogField{"caller", string(f.Caller())}, LogField{"dest", string(f.Service())}, LogField{"method", string(f.Method())}, ).Warn("Failed to send call with modified arg2.") } // fragmentingSend always sends new frames in place of the old frame so we must // release it separately return _relayShouldRelease, nil } call.SentBytes(f.Frame.Header.FrameSize()) sent, failure := relayToDest.destination.Receive(f.Frame, requestFrame) if !sent { r.failRelayItem(r.outbound, origID, failure, errFrameNotSent) return _relayShouldRelease, nil } return _relayNoRelease, nil } // Handle all frames except messageTypeCallReq. func (r *Relayer) handleNonCallReq(f *Frame) (shouldRelease bool, _ error) { frameType := frameTypeFor(f) finished := finishesCall(f) // If we read a request frame, we need to use the outbound map to decide // the destination. Otherwise, we use the inbound map. items := r.outbound if frameType == responseFrame { items = r.inbound } // Stop the timeout if the call if finished. item, stopped, ok := items.Get(f.Header.ID, finished /* stopTimeout */) if !ok { return _relayShouldRelease, errUnknownID } if item.tomb || (finished && !stopped) { // Item has previously timed out, or is in the process of timing out. // TODO: metrics for late-arriving frames. return _relayShouldRelease, nil } switch f.messageType() { case messageTypeCallRes: // Invoke call.CallResponse() if we get a valid call response frame. 
cr, err := newLazyCallRes(f) if err == nil { item.call.CallResponse(cr) } else { r.logger.WithFields( ErrField(err), LogField{"id", f.Header.ID}, ).Error("Malformed callRes frame.") } case messageTypeCallReqContinue: // Recalculate and update the checksum for this frame if it has non-nil item.mutatedChecksum // (meaning the call was mutated) and it is a callReqContinue frame. if item.mutatedChecksum != nil { r.updateMutatedCallReqContinueChecksum(f, item.mutatedChecksum) } } // Track sent/received bytes. We don't do this before we check // for timeouts, since this should only be called before call.End(). item.reportRelayBytes(frameType, f.Header.FrameSize()) originalID := f.Header.ID f.Header.ID = item.remapID sent, failure := item.destination.Receive(f, frameType) if !sent { r.failRelayItem(items, originalID, failure, errFrameNotSent) return _relayShouldRelease, nil } if finished { r.finishRelayItem(items, originalID) } return _relayNoRelease, nil } // addRelayItem adds a relay item to either outbound or inbound. func (r *Relayer) addRelayItem(isOriginator bool, id, remapID uint32, destination *Relayer, ttl time.Duration, span Span, call RelayCall, mutatedChecksum Checksum) relayItem { item := relayItem{ isOriginator: isOriginator, call: call, remapID: remapID, destination: destination, span: span, mutatedChecksum: mutatedChecksum, } items := r.inbound if isOriginator { items = r.outbound } item.timeout = r.timeouts.Get() items.Add(id, item) item.timeout.Start(ttl, items, id, isOriginator) return item } func (r *Relayer) timeoutRelayItem(items *relayItems, id uint32, isOriginator bool) { item, ok := items.Entomb(id, _relayTombTTL) if !ok { return } if isOriginator { r.conn.SendSystemError(id, item.span, ErrTimeout) item.call.Failed("timeout") item.call.End() } r.decrementPending() } // failRelayItem tombs the relay item so that future frames for this call are not // forwarded. 
We keep the relay item tombed, rather than delete it to ensure that // future frames do not cause error logs. func (r *Relayer) failRelayItem(items *relayItems, id uint32, reason string, err error) { // Stop the timeout, so we either fail it here, or in the timeout goroutine but not both. item, stopped, found := items.Get(id, true /* stopTimeout */) if !found { items.logger.WithFields(LogField{"id", id}).Warn("Attempted to fail non-existent relay item.") return } if !stopped { return } // Entomb it so that we don't get unknown exchange errors on further frames // for this call. item, ok := items.Entomb(id, _relayTombTTL) if !ok { return } if item.isOriginator { // If the client is too slow, then there's no point sending an error frame. if reason != _relayErrorSourceConnSlow { r.conn.SendSystemError(id, item.span, fmt.Errorf("%v: %v", reason, err)) } item.call.Failed(reason) item.call.End() } r.decrementPending() } func (r *Relayer) finishRelayItem(items *relayItems, id uint32) { item, ok := items.Delete(id) if !ok { return } if item.isOriginator { item.call.End() if item.mutatedChecksum != nil { item.mutatedChecksum.Release() } } r.decrementPending() } func (r *Relayer) decrementPending() { r.pending.Dec() r.conn.checkExchanges() } func (r *Relayer) canClose() bool { if r == nil { return true } return r.countPending() == 0 } func (r *Relayer) countPending() uint32 { return r.pending.Load() } func (r *Relayer) receiverItems(fType frameType) *relayItems { if fType == requestFrame { return r.inbound } return r.outbound } func (r *Relayer) handleLocalCallReq(cr *lazyCallReq) (shouldRelease bool) { // Check whether this is a service we want to handle locally. if _, ok := r.localHandler[string(cr.Service())]; !ok { return _relayNoRelease } f := cr.Frame // We can only handle non-fragmented calls in the relay channel. // This is a simplification to avoid back references from a mex to a // relayItem so that the relayItem is cleared when the call completes. 
if cr.HasMoreFragments() { r.logger.WithFields( LogField{"id", cr.Header.ID}, LogField{"source", string(cr.Caller())}, LogField{"dest", string(cr.Service())}, LogField{"method", string(cr.Method())}, ).Error("Received fragmented callReq intended for local relay channel, can only handle unfragmented calls.") r.conn.SendSystemError(f.Header.ID, cr.Span(), errRelayMethodFragmented) return _relayShouldRelease } if release := r.conn.handleFrameNoRelay(f); release { r.conn.opts.FramePool.Release(f) } return _relayShouldRelease } func (r *Relayer) fragmentingSend(call RelayCall, f *lazyCallReq, relayToDest relayItem, origID uint32) error { if f.isArg2Fragmented { return errFragmentedArg2WithAppend } if !bytes.Equal(f.as, _tchanThriftValueBytes) { return fmt.Errorf("%v: got %s", errArg2ThriftOnly, f.as) } cs := relayToDest.mutatedChecksum // TODO(echung): should we pool the writers? fragWriter := newFragmentingWriter( r.logger, r.newFragmentSender(relayToDest.destination, f, origID, call), cs, ) arg2Writer, err := fragWriter.ArgWriter(false /* last */) if err != nil { return fmt.Errorf("get arg2 writer: %v", err) } if err := writeArg2WithAppends(arg2Writer, f.arg2(), f.arg2Appends); err != nil { return fmt.Errorf("write arg2: %v", err) } if err := arg2Writer.Close(); err != nil { return fmt.Errorf("close arg2 writer: %v", err) } if err := NewArgWriter(fragWriter.ArgWriter(true /* last */)).Write(f.arg3()); err != nil { return errors.New("arg3 write failed") } return nil } func (r *Relayer) updateMutatedCallReqContinueChecksum(f *Frame, cs Checksum) { rbuf := typed.NewReadBuffer(f.SizedPayload()) rbuf.SkipBytes(1) // flags rbuf.SkipBytes(1) // checksum type: this should match the checksum type of the callReq frame checksumRef := typed.BytesRef(rbuf.ReadBytes(cs.Size())) // We only support non-fragmented arg2 for mutated calls, so by the time we hit callReqContinue both // arg1 and arg2 must already have been read. 
As the call would be finished when we've read all of // arg3, it isn't necessary to separately track its completion. // // In theory we could have a frame with 0-length arg3, which can happen if a manual flush occurred // after writing 0 bytes for arg3. This is handled correctly by // 1) reading n=0 (nArg3) // 2) reading 0 bytes from the rbuf // 3) updating the checksum with the current running checksum // // Additionally, if the checksum type results in a 0-length checksum, the .Update() would // become a copy between empty slices, which correctly becomes a noop. // TODO(cinchurge): include a test for len(arg3)==0 in the unit tests n := rbuf.ReadUint16() cs.Add(rbuf.ReadBytes(int(n))) checksumRef.Update(cs.Sum()) } func writeArg2WithAppends(w io.WriteCloser, arg2 []byte, appends []relay.KeyVal) (err error) { if len(arg2) < 2 { return errNoNHInArg2 } writer := typed.NewWriter(w) // nh:2 is the first two bytes of arg2, which should always be present nh := binary.BigEndian.Uint16(arg2[:2]) + uint16(len(appends)) writer.WriteUint16(nh) // arg2[2:] is the existing sequence of key/val pairs, which we can just copy // over verbatim if len(arg2) > 2 { writer.WriteBytes(arg2[2:]) } // append new key/val pairs to end of arg2 for _, kv := range appends { writer.WriteLen16Bytes(kv.Key) writer.WriteLen16Bytes(kv.Val) } return writer.Err() } func frameTypeFor(f *Frame) frameType { switch t := f.Header.messageType; t { case messageTypeCallRes, messageTypeCallResContinue, messageTypeError, messageTypePingRes: return responseFrame case messageTypeCallReq, messageTypeCallReqContinue, messageTypePingReq, messageTypeCancel: return requestFrame default: panic(fmt.Sprintf("unsupported frame type: %v", t)) } } func determinesCallSuccess(f *Frame) (succeeded bool, failMsg string) { switch f.messageType() { case messageTypeError: msg := newLazyError(f).Code().MetricsKey() return false, msg case messageTypeCancel: return false, "canceled" case messageTypeCallRes: if isCallResOK(f) { return 
true, "" } return false, "application-error" default: return false, "" } } func validateRelayMaxTimeout(d time.Duration, logger Logger) time.Duration { maxMillis := d / time.Millisecond if maxMillis > 0 && maxMillis <= math.MaxUint32 { return d } if d == 0 { return _defaultRelayMaxTimeout } logger.WithFields( LogField{"configuredMaxTimeout", d}, LogField{"defaultMaxTimeout", _defaultRelayMaxTimeout}, ).Warn("Configured RelayMaxTimeout is invalid, using default instead.") return _defaultRelayMaxTimeout } type sentBytesReporter interface { SentBytes(size uint16) } type relayFragmentSender struct { callReq *lazyCallReq framePool FramePool frameReceiver frameReceiver failRelayItemFunc func(items *relayItems, id uint32, failure string, err error) outboundRelayItems *relayItems origID uint32 sentReporter sentBytesReporter } func (r *Relayer) newFragmentSender(dstRelay frameReceiver, cr *lazyCallReq, origID uint32, sentReporter sentBytesReporter) *relayFragmentSender { // TODO(cinchurge): pool fragment senders return &relayFragmentSender{ callReq: cr, framePool: r.conn.opts.FramePool, frameReceiver: dstRelay, failRelayItemFunc: r.failRelayItem, outboundRelayItems: r.outbound, origID: origID, sentReporter: sentReporter, } } func (rfs *relayFragmentSender) newFragment(initial bool, checksum Checksum) (*writableFragment, error) { frame := rfs.framePool.Get() frame.Header.ID = rfs.callReq.Header.ID if initial { frame.Header.messageType = messageTypeCallReq } else { frame.Header.messageType = messageTypeCallReqContinue } contents := typed.NewWriteBuffer(frame.Payload[:]) // flags:1 // Flags MUST be copied over from the callReq frame to all new fragments since if there are more // fragments to follow the callReq, the destination needs to know about this or those frames will // be dropped from the call flagsRef := contents.DeferByte() flagsRef.Update(rfs.callReq.Payload[_flagsIndex]) if initial { // Copy all data before the checksum for the initial frame 
contents.WriteBytes(rfs.callReq.Payload[_flagsIndex+1 : rfs.callReq.checksumTypeOffset]) } // checksumType:1 contents.WriteSingleByte(byte(checksum.TypeCode())) // checksum: checksum.Size() checksumRef := contents.DeferBytes(checksum.Size()) if initial { // arg1~1: write arg1 to the initial frame contents.WriteUint16(uint16(len(rfs.callReq.method))) contents.WriteBytes(rfs.callReq.method) checksum.Add(rfs.callReq.method) } // TODO(cinchurge): pool writableFragment return &writableFragment{ flagsRef: flagsRef, checksumRef: checksumRef, // checksum will be released by the relayer when the call is finished checksum: &noReleaseChecksum{Checksum: checksum}, contents: contents, frame: frame, }, contents.Err() } func (rfs *relayFragmentSender) flushFragment(wf *writableFragment) error { wf.frame.Header.SetPayloadSize(uint16(wf.contents.BytesWritten())) rfs.sentReporter.SentBytes(wf.frame.Header.FrameSize()) sent, failure := rfs.frameReceiver.Receive(wf.frame, requestFrame) if !sent { rfs.failRelayItemFunc(rfs.outboundRelayItems, rfs.origID, failure, errFrameNotSent) rfs.framePool.Release(wf.frame) return nil } return nil } func (rfs *relayFragmentSender) doneSending() {} ================================================ FILE: relay_api.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import "github.com/uber/tchannel-go/relay"

// RelayHost is the interface used to create RelayCalls when the relay
// receives an incoming call.
type RelayHost interface {
	// SetChannel is called on creation of the channel. It's used to set a
	// channel reference which can be used to get references to *Peer.
	SetChannel(ch *Channel)

	// Start starts a new RelayCall given the call frame and connection.
	// It may return a call and an error, in which case the caller will
	// call Failed/End on the RelayCall.
	Start(relay.CallFrame, *relay.Conn) (RelayCall, error)
}

// RelayCall abstracts away peer selection, stats, and any other business
// logic from the underlying relay implementation. A RelayCall may not
// have a destination if there was an error during peer selection
// (which should be returned from Start).
type RelayCall interface {
	// Destination returns the selected peer (if there was no error from Start).
	Destination() (peer *Peer, ok bool)

	// SentBytes is called when a frame is sent to the destination peer.
	SentBytes(uint16)

	// ReceivedBytes is called when a frame is received from the destination peer.
	ReceivedBytes(uint16)

	// CallResponse is called when a call response frame is received from the
	// destination peer.
	CallResponse(relay.RespFrame)

	// Succeeded is called when the call succeeded (possibly after retrying).
	Succeeded()

	// Failed is called when the call failed, with the reason for the failure.
	Failed(reason string)

	// End stats collection for this RPC. Will be called exactly once.
End() } ================================================ FILE: relay_benchmark_test.go ================================================ package tchannel_test import ( "fmt" "sync" "testing" "time" "github.com/bmizerany/perks/quantile" "github.com/stretchr/testify/require" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/benchmark" "github.com/uber/tchannel-go/relay" "github.com/uber/tchannel-go/testutils" ) type benchmarkParams struct { servers, clients int requestSize int appends []relay.KeyVal } type workerControl struct { start sync.WaitGroup unblockStart chan struct{} done sync.WaitGroup } func init() { benchmark.BenchmarkDir = "./benchmark/" } func newWorkerControl(numWorkers int) *workerControl { wc := &workerControl{ unblockStart: make(chan struct{}), } wc.start.Add(numWorkers) wc.done.Add(numWorkers) return wc } func (c *workerControl) WaitForStart(f func()) { c.start.Wait() f() close(c.unblockStart) } func (c *workerControl) WaitForEnd() { c.done.Wait() } func (c *workerControl) WorkerStart() { c.start.Done() <-c.unblockStart } func (c *workerControl) WorkerDone() { c.done.Done() } func defaultParams() benchmarkParams { return benchmarkParams{ servers: 2, clients: 2, requestSize: 1024, } } func closeAndVerify(b *testing.B, ch *Channel) { ch.Close() isChanClosed := func() bool { return ch.State() == ChannelClosed } if !testutils.WaitFor(time.Second, isChanClosed) { b.Errorf("Timed out waiting for channel to close, state: %v", ch.State()) } } func benchmarkRelay(b *testing.B, p benchmarkParams) { b.SetBytes(int64(p.requestSize)) b.ReportAllocs() services := make(map[string][]string) servers := make([]benchmark.Server, p.servers) for i := range servers { servers[i] = benchmark.NewServer( benchmark.WithServiceName("svc"), benchmark.WithRequestSize(p.requestSize), benchmark.WithExternalProcess(), ) defer servers[i].Close() services["svc"] = append(services["svc]"], servers[i].HostPort()) } relay, err := benchmark.NewRealRelay(services, p.appends) 
require.NoError(b, err, "Failed to create relay") defer relay.Close() clients := make([]benchmark.Client, p.clients) for i := range clients { clients[i] = benchmark.NewClient([]string{relay.HostPort()}, benchmark.WithServiceName("svc"), benchmark.WithRequestSize(p.requestSize), benchmark.WithExternalProcess(), benchmark.WithTimeout(10*time.Second), ) defer clients[i].Close() require.NoError(b, clients[i].Warmup(), "Warmup failed") } quantileVals := []float64{0.50, 0.95, 0.99, 1.0} quantiles := make([]*quantile.Stream, p.clients) for i := range quantiles { quantiles[i] = quantile.NewTargeted(quantileVals...) } wc := newWorkerControl(p.clients) dec := testutils.Decrementor(b.N) var wg sync.WaitGroup errC := make(chan error, 1) defer close(errC) for i, c := range clients { wg.Add(1) go func(i int, c benchmark.Client) { defer wg.Done() // Do a warm up call. c.RawCall(1) wc.WorkerStart() defer wc.WorkerDone() for { tokens := dec.Multiple(200) if tokens == 0 { break } durations, err := c.RawCall(tokens) if err != nil { errC <- err return } for _, d := range durations { quantiles[i].Insert(float64(d)) } } }(i, c) } wg.Wait() if err := <-errC; err != nil { b.Fatalf("Call failed: %v", err) } var started time.Time wc.WaitForStart(func() { b.ResetTimer() started = time.Now() }) wc.WaitForEnd() duration := time.Since(started) fmt.Printf("\nb.N: %v Duration: %v RPS = %0.0f\n", b.N, duration, float64(b.N)/duration.Seconds()) // Merge all the quantiles into 1 for _, q := range quantiles[1:] { quantiles[0].Merge(q.Samples()) } for _, q := range quantileVals { fmt.Printf(" %0.4f = %v\n", q, time.Duration(quantiles[0].Query(q))) } fmt.Println() } func BenchmarkRelayNoLatencies(b *testing.B) { server := benchmark.NewServer( benchmark.WithServiceName("svc"), benchmark.WithExternalProcess(), benchmark.WithNoLibrary(), ) defer server.Close() hostMapping := map[string][]string{"svc": {server.HostPort()}} relay, err := benchmark.NewRealRelay(hostMapping, nil) require.NoError(b, err, 
"NewRealRelay failed") defer relay.Close() client := benchmark.NewClient([]string{relay.HostPort()}, benchmark.WithServiceName("svc"), benchmark.WithExternalProcess(), benchmark.WithNoLibrary(), benchmark.WithNumClients(10), benchmark.WithNoChecking(), benchmark.WithNoDurations(), benchmark.WithTimeout(10*time.Second), ) defer client.Close() require.NoError(b, client.Warmup(), "client.Warmup failed") b.ResetTimer() started := time.Now() for _, calls := range testutils.Batch(b.N, 10000) { if _, err := client.RawCall(calls); err != nil { b.Fatalf("Calls failed: %v", err) } } duration := time.Since(started) fmt.Printf("\nb.N: %v Duration: %v RPS = %0.0f\n", b.N, duration, float64(b.N)/duration.Seconds()) } func BenchmarkRelay2Servers5Clients1k(b *testing.B) { p := defaultParams() p.clients = 5 p.servers = 2 benchmarkRelay(b, p) } func BenchmarkRelay4Servers20Clients1k(b *testing.B) { p := defaultParams() p.clients = 20 p.servers = 4 benchmarkRelay(b, p) } func BenchmarkRelay2Servers5Clients4k(b *testing.B) { p := defaultParams() p.requestSize = 4 * 1024 p.clients = 5 p.servers = 2 benchmarkRelay(b, p) } func BenchmarkRelayAppends(b *testing.B) { for _, n := range []int{0, 1, 2, 5, 10} { b.Run(fmt.Sprintf("%v appends", n), func(b *testing.B) { p := defaultParams() for i := 0; i < n; i++ { p.appends = append(p.appends, relay.KeyVal{Key: []byte("foo"), Val: []byte("bar")}) } b.ResetTimer() benchmarkRelay(b, p) }) } } ================================================ FILE: relay_fragment_sender_test.go ================================================ package tchannel import ( "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/tchannel-go/relay" "github.com/uber/tchannel-go/testutils/thriftarg2test" ) var _ frameReceiver = (*dummyFrameReceiver)(nil) type dummyFrameReceiver struct { retSent bool retFailureReason string pool FramePool // mutable gotPayload []byte } func newDummyFrameReceiver(retSent bool, 
retFailureReason string, pool FramePool) *dummyFrameReceiver { return &dummyFrameReceiver{ retSent: retSent, retFailureReason: retFailureReason, pool: pool, } } func (d *dummyFrameReceiver) Receive(f *Frame, fType frameType) (sent bool, failureReason string) { // Keep a record of the received payload for verification d.gotPayload = make([]byte, len(f.SizedPayload())) copy(d.gotPayload, f.SizedPayload()) // Frames should be released after transmission if d.retSent { d.pool.Release(f) } return d.retSent, d.retFailureReason } type noopSentReporter struct{} func (r *noopSentReporter) SentBytes(_ uint16) {} func TestRelayFragmentSender(t *testing.T) { tests := []struct { msg string sent bool failure string wantFailureRelayItemFuncCalled bool wantPayload []byte }{ { msg: "successful send", sent: true, wantPayload: []byte("hello, world"), }, { msg: "send failure", sent: false, failure: "something bad happened", wantFailureRelayItemFuncCalled: true, }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { var failRelayItemFuncCalled bool pool := NewCheckedFramePoolForTest() defer func() { CheckFramePoolIsEmpty(t, pool) }() cr := reqHasAll.req(t) receiver := newDummyFrameReceiver(tt.sent, tt.failure, pool) rfs := relayFragmentSender{ callReq: &cr, framePool: pool, frameReceiver: receiver, failRelayItemFunc: func(items *relayItems, id uint32, failure string, err error) { failRelayItemFuncCalled = true assert.Equal(t, uint32(123), id, "got unexpected id") assert.Equal(t, tt.failure, failure, "got unexpected failure string") assert.Error(t, err, "missing err") }, origID: 123, sentReporter: &noopSentReporter{}, } wf, err := rfs.newFragment(true, nullChecksum{}) require.NoError(t, err) // Get the payload expected by receive before the fragment is released wantPayload := make([]byte, wf.contents.BytesWritten()) copy(wantPayload, wf.frame.Payload[:wf.contents.BytesWritten()]) err = rfs.flushFragment(wf) require.NoError(t, err) assert.Equal(t, wantPayload, 
receiver.gotPayload) assert.Equal(t, tt.wantFailureRelayItemFuncCalled, failRelayItemFuncCalled, "unexpected failRelayItemFunc called state") }) } } type dummyArgWriter struct { numCall int writeError []string closeError string bytesWritten []byte } func (w *dummyArgWriter) Write(b []byte) (int, error) { retErr := w.writeError[w.numCall] w.bytesWritten = append(w.bytesWritten, b...) w.numCall++ if retErr != "" { return 0, errors.New(retErr) } return len(b), nil } func (w *dummyArgWriter) Close() error { if w.closeError != "" { return errors.New(w.closeError) } return nil } func TestWriteArg2WithAppends(t *testing.T) { tests := []struct { msg string writer *dummyArgWriter arg2Map map[string]string overrideArg2Buf []byte appends []relay.KeyVal wantError string }{ { msg: "write success without appends", writer: &dummyArgWriter{ writeError: []string{ "", // nh "", // arg2 }, }, arg2Map: exampleArg2Map, }, { msg: "write success with appends", writer: &dummyArgWriter{ writeError: []string{ "", // nh "", // arg2 "", // key length "", // key "", // val length "", // val }, }, arg2Map: exampleArg2Map, appends: []relay.KeyVal{ {Key: []byte("foo"), Val: []byte("bar")}, }, }, { msg: "no nh in data", writer: &dummyArgWriter{ writeError: []string{ assert.AnError.Error(), // nh }, }, overrideArg2Buf: []byte{0}, wantError: "no nh in arg2", }, { msg: "write nh fails", writer: &dummyArgWriter{ writeError: []string{ assert.AnError.Error(), // nh }, }, arg2Map: exampleArg2Map, wantError: assert.AnError.Error(), }, { msg: "write arg2 fails", writer: &dummyArgWriter{ writeError: []string{ "", // write nh assert.AnError.Error(), // write arg2 }, }, arg2Map: exampleArg2Map, wantError: assert.AnError.Error(), }, { msg: "write append key length fails", writer: &dummyArgWriter{ writeError: []string{ "", // write nh "", // write arg2 assert.AnError.Error(), // write key length }, }, arg2Map: exampleArg2Map, appends: []relay.KeyVal{ {Key: []byte("foo"), Val: []byte("bar")}, }, wantError: 
assert.AnError.Error(), }, { msg: "write append key fails", writer: &dummyArgWriter{ writeError: []string{ "", // write nh "", // write arg2 "", // write key length assert.AnError.Error(), // write key }, }, arg2Map: exampleArg2Map, appends: []relay.KeyVal{ {Key: []byte("foo"), Val: []byte("bar")}, }, wantError: assert.AnError.Error(), }, { msg: "write append val length fails", writer: &dummyArgWriter{ writeError: []string{ "", // write nh "", // write arg2 "", // write key length "", // write key assert.AnError.Error(), // write val length }, }, arg2Map: exampleArg2Map, appends: []relay.KeyVal{ {Key: []byte("foo"), Val: []byte("bar")}, }, wantError: assert.AnError.Error(), }, { msg: "write append val fails", writer: &dummyArgWriter{ writeError: []string{ "", // write nh "", // write arg2 "", // write key length "", // write key "", // write val length assert.AnError.Error(), // write val }, }, arg2Map: exampleArg2Map, appends: []relay.KeyVal{ {Key: []byte("foo"), Val: []byte("bar")}, }, wantError: assert.AnError.Error(), }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { var arg2buf []byte if tt.overrideArg2Buf != nil { arg2buf = tt.overrideArg2Buf } else if len(tt.arg2Map) > 0 { arg2buf = thriftarg2test.BuildKVBuffer(tt.arg2Map) } err := writeArg2WithAppends(tt.writer, arg2buf, tt.appends) if tt.wantError != "" { require.EqualError(t, err, tt.wantError) return } require.NoError(t, tt.writer.Close()) finalMap := make(map[string]string) for k, v := range tt.arg2Map { finalMap[k] = v } for _, kv := range tt.appends { finalMap[string(kv.Key)] = string(kv.Val) } require.NoError(t, err) assert.Equal(t, finalMap, thriftarg2test.MustReadKVBuffer(t, tt.writer.bytesWritten)) }) } } ================================================ FILE: relay_internal_test.go ================================================ package tchannel import ( "testing" "time" "github.com/uber/tchannel-go/typed" "github.com/stretchr/testify/assert" ) func TestFinishesCallResponses(t 
*testing.T) { tests := []struct { msgType messageType flags byte finishesCall bool }{ {messageTypeCallRes, 0x00, true}, {messageTypeCallRes, 0x01, false}, {messageTypeCallRes, 0x02, true}, {messageTypeCallRes, 0x03, false}, {messageTypeCallRes, 0x04, true}, {messageTypeCallResContinue, 0x00, true}, {messageTypeCallResContinue, 0x01, false}, {messageTypeCallResContinue, 0x02, true}, {messageTypeCallResContinue, 0x03, false}, {messageTypeCallResContinue, 0x04, true}, // By definition, callreq should never terminate an RPC. {messageTypeCallReq, 0x00, false}, {messageTypeCallReq, 0x01, false}, {messageTypeCallReq, 0x02, false}, {messageTypeCallReq, 0x03, false}, {messageTypeCallReq, 0x04, false}, } for _, tt := range tests { f := NewFrame(100) fh := FrameHeader{ size: uint16(0xFF34), messageType: tt.msgType, ID: 0xDEADBEEF, } f.Header = fh fh.write(typed.NewWriteBuffer(f.headerBuffer)) payload := typed.NewWriteBuffer(f.Payload) payload.WriteSingleByte(tt.flags) assert.Equal(t, tt.finishesCall, finishesCall(f), "Wrong isLast for flags %v and message type %v", tt.flags, tt.msgType) } } func TestRelayTimerPoolMisuse(t *testing.T) { tests := []struct { msg string f func(*relayTimer) }{ { msg: "release without stop", f: func(rt *relayTimer) { rt.Start(time.Hour, &relayItems{}, 0, false /* isOriginator */) rt.Release() }, }, { msg: "start twice", f: func(rt *relayTimer) { rt.Start(time.Hour, &relayItems{}, 0, false /* isOriginator */) rt.Start(time.Hour, &relayItems{}, 0, false /* isOriginator */) }, }, { msg: "underlying timer is already active", f: func(rt *relayTimer) { rt.timer.Reset(time.Hour) rt.Start(time.Hour, &relayItems{}, 0, false /* isOriginator */) }, }, { msg: "use timer after releasing it", f: func(rt *relayTimer) { rt.Release() rt.Stop() }, }, } for _, tt := range tests { trigger := func(*relayItems, uint32, bool) {} rtp := newRelayTimerPool(trigger, true /* verify */) rt := rtp.Get() assert.Panics(t, func() { tt.f(rt) }, tt.msg) } } 
================================================ FILE: relay_messages.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "bytes" "encoding/binary" "fmt" "time" "github.com/uber/tchannel-go/relay" "github.com/uber/tchannel-go/thrift/arg2" "github.com/uber/tchannel-go/typed" ) var _ relay.RespFrame = (*lazyCallRes)(nil) var ( _callerNameKeyBytes = []byte(CallerName) _routingDelegateKeyBytes = []byte(RoutingDelegate) _routingKeyKeyBytes = []byte(RoutingKey) _argSchemeKeyBytes = []byte(ArgScheme) _tchanThriftValueBytes = []byte(Thrift) ) const ( // Common to many frame types. _flagsIndex = 0 // For call req, indexes into the frame. // Use int for indexes to avoid overflow caused by accidental byte arithmentic. 
_ttlIndex int = 1 _ttlLen int = 4 _spanIndex int = _ttlIndex + _ttlLen _spanLength int = 25 _serviceLenIndex int = _spanIndex + _spanLength _serviceNameIndex int = _serviceLenIndex + 1 // For call res and call res continue. _resCodeOK = 0x00 _resCodeIndex int = 1 // For error. _errCodeIndex int = 0 ) type lazyError struct { *Frame } func newLazyError(f *Frame) lazyError { if msgType := f.Header.messageType; msgType != messageTypeError { panic(fmt.Errorf("newLazyError called for wrong messageType: %v", msgType)) } return lazyError{f} } func (e lazyError) Code() SystemErrCode { return SystemErrCode(e.Payload[_errCodeIndex]) } type lazyCallRes struct { *Frame as []byte arg2IsFragmented bool arg2Payload []byte } func newLazyCallRes(f *Frame) (lazyCallRes, error) { if msgType := f.Header.messageType; msgType != messageTypeCallRes { panic(fmt.Errorf("newLazyCallRes called for wrong messageType: %v", msgType)) } rbuf := typed.NewReadBuffer(f.SizedPayload()) rbuf.SkipBytes(1) // flags rbuf.SkipBytes(1) // code rbuf.SkipBytes(_spanLength) // tracing var as []byte nh := int(rbuf.ReadSingleByte()) for i := 0; i < nh; i++ { keyLen := int(rbuf.ReadSingleByte()) key := rbuf.ReadBytes(keyLen) valLen := int(rbuf.ReadSingleByte()) val := rbuf.ReadBytes(valLen) if bytes.Equal(key, _argSchemeKeyBytes) { as = val continue } } csumtype := ChecksumType(rbuf.ReadSingleByte()) // csumtype rbuf.SkipBytes(csumtype.ChecksumSize()) // csum // arg1: ignored narg1 := int(rbuf.ReadUint16()) rbuf.SkipBytes(narg1) // arg2: keep track of payload narg2 := int(rbuf.ReadUint16()) arg2Payload := rbuf.ReadBytes(narg2) arg2IsFragmented := rbuf.BytesRemaining() == 0 && hasMoreFragments(f) // arg3: ignored // Make sure we didn't hit any issues reading the buffer if err := rbuf.Err(); err != nil { return lazyCallRes{}, fmt.Errorf("read response frame: %v", err) } return lazyCallRes{ Frame: f, as: as, arg2IsFragmented: arg2IsFragmented, arg2Payload: arg2Payload, }, nil } // OK implements relay.RespFrame func 
(cr lazyCallRes) OK() bool { return isCallResOK(cr.Frame) } // ArgScheme implements relay.RespFrame func (cr lazyCallRes) ArgScheme() []byte { return cr.as } // Arg2IsFragmented implements relay.RespFrame func (cr lazyCallRes) Arg2IsFragmented() bool { return cr.arg2IsFragmented } // Arg2 implements relay.RespFrame func (cr lazyCallRes) Arg2() []byte { return cr.arg2Payload } type lazyCallReq struct { *Frame checksumTypeOffset uint16 arg2StartOffset, arg2EndOffset uint16 arg3StartOffset uint16 caller, method, delegate, key, as []byte arg2Appends []relay.KeyVal checksumType ChecksumType isArg2Fragmented bool // Intentionally an array to combine allocations with that of lazyCallReq arg2InitialBuf [1]relay.KeyVal } // TODO: Consider pooling lazyCallReq and using pointers to the struct. func newLazyCallReq(f *Frame) (*lazyCallReq, error) { if msgType := f.Header.messageType; msgType != messageTypeCallReq { panic(fmt.Errorf("newLazyCallReq called for wrong messageType: %v", msgType)) } cr := &lazyCallReq{Frame: f} cr.arg2Appends = cr.arg2InitialBuf[:0] rbuf := typed.NewReadBuffer(f.SizedPayload()) rbuf.SkipBytes(_serviceLenIndex) // service~1 serviceLen := rbuf.ReadSingleByte() rbuf.SkipBytes(int(serviceLen)) // nh:1 (hk~1 hv~1){nh} numHeaders := int(rbuf.ReadSingleByte()) for i := 0; i < numHeaders; i++ { keyLen := int(rbuf.ReadSingleByte()) key := rbuf.ReadBytes(keyLen) valLen := int(rbuf.ReadSingleByte()) val := rbuf.ReadBytes(valLen) if bytes.Equal(key, _argSchemeKeyBytes) { cr.as = val } else if bytes.Equal(key, _callerNameKeyBytes) { cr.caller = val } else if bytes.Equal(key, _routingDelegateKeyBytes) { cr.delegate = val } else if bytes.Equal(key, _routingKeyKeyBytes) { cr.key = val } } // csumtype:1 (csum:4){0,1} arg1~2 arg2~2 arg3~2 cr.checksumTypeOffset = uint16(rbuf.BytesRead()) cr.checksumType = ChecksumType(rbuf.ReadSingleByte()) rbuf.SkipBytes(cr.checksumType.ChecksumSize()) // arg1~2 arg1Len := int(rbuf.ReadUint16()) cr.method = rbuf.ReadBytes(arg1Len) // 
arg2~2 arg2Len := rbuf.ReadUint16() cr.arg2StartOffset = uint16(rbuf.BytesRead()) cr.arg2EndOffset = cr.arg2StartOffset + arg2Len // arg2 is fragmented if we don't see arg3 in this frame. rbuf.SkipBytes(int(arg2Len)) cr.isArg2Fragmented = rbuf.BytesRemaining() == 0 && cr.HasMoreFragments() if !cr.isArg2Fragmented { // arg3~2 rbuf.SkipBytes(2) cr.arg3StartOffset = uint16(rbuf.BytesRead()) } if rbuf.Err() != nil { return nil, rbuf.Err() } return cr, nil } // Caller returns the name of the originator of this callReq. func (f *lazyCallReq) Caller() []byte { return f.caller } // Service returns the name of the destination service for this callReq. func (f *lazyCallReq) Service() []byte { l := f.Payload[_serviceLenIndex] return f.Payload[_serviceNameIndex : _serviceNameIndex+int(l)] } // Method returns the name of the method being called. func (f *lazyCallReq) Method() []byte { return f.method } // RoutingDelegate returns the routing delegate for this call req, if any. func (f *lazyCallReq) RoutingDelegate() []byte { return f.delegate } // RoutingKey returns the routing delegate for this call req, if any. func (f *lazyCallReq) RoutingKey() []byte { return f.key } // TTL returns the time to live for this callReq. func (f *lazyCallReq) TTL() time.Duration { ttl := binary.BigEndian.Uint32(f.Payload[_ttlIndex : _ttlIndex+_ttlLen]) return time.Duration(ttl) * time.Millisecond } // SetTTL overwrites the frame's TTL. func (f *lazyCallReq) SetTTL(d time.Duration) { ttl := uint32(d / time.Millisecond) binary.BigEndian.PutUint32(f.Payload[_ttlIndex:_ttlIndex+_ttlLen], ttl) } // Span returns the Span func (f *lazyCallReq) Span() Span { return callReqSpan(f.Frame) } // HasMoreFragments returns whether the callReq has more fragments. 
func (f *lazyCallReq) HasMoreFragments() bool {
	flags := f.Payload[_flagsIndex]
	return flags&hasMoreFragmentsFlag != 0
}

// Arg2EndOffset returns the offset from start of payload to the end of Arg2
// in bytes, and hasMore to be true if there are more frames and arg3 has
// not started.
func (f *lazyCallReq) Arg2EndOffset() (_ int, hasMore bool) {
	hasMore = f.isArg2Fragmented
	return int(f.arg2EndOffset), hasMore
}

// Arg2StartOffset returns the offset from start of payload to the beginning
// of Arg2 in bytes.
func (f *lazyCallReq) Arg2StartOffset() int {
	return int(f.arg2StartOffset)
}

// arg2 returns the bytes of Arg2 that are present in this frame.
func (f *lazyCallReq) arg2() []byte {
	start, end := f.arg2StartOffset, f.arg2EndOffset
	return f.Payload[start:end]
}

// arg3 returns the sized payload from the recorded Arg3 start offset onwards.
func (f *lazyCallReq) arg3() []byte {
	sized := f.SizedPayload()
	return sized[f.arg3StartOffset:]
}

// Arg2Iterator returns the iterator for reading Arg2 key value pair
// of TChannel-Thrift Arg Scheme. It returns an error wrapping the text of
// errArg2ThriftOnly when the request's arg scheme is anything else.
func (f *lazyCallReq) Arg2Iterator() (arg2.KeyValIterator, error) {
	if bytes.Equal(f.as, _tchanThriftValueBytes) {
		return arg2.NewKeyValIterator(f.arg2())
	}
	return arg2.KeyValIterator{}, fmt.Errorf("%v: got %s", errArg2ThriftOnly, f.as)
}

// Arg2Append records a key/value pair in the request's list of pending Arg2
// appends. The key and val slices are stored as given, without copying.
func (f *lazyCallReq) Arg2Append(key, val []byte) {
	kv := relay.KeyVal{Key: key, Val: val}
	f.arg2Appends = append(f.arg2Appends, kv)
}

// finishesCall checks whether this frame is the last one we should expect for
// this RPC req-res.
func finishesCall(f *Frame) bool { switch f.messageType() { case messageTypeError, messageTypeCancel: return true case messageTypeCallRes, messageTypeCallResContinue: flags := f.Payload[_flagsIndex] return flags&hasMoreFragmentsFlag == 0 default: return false } } // isCallResOK indicates whether the call was successful func isCallResOK(f *Frame) bool { return f.Payload[_resCodeIndex] == _resCodeOK } // hasMoreFragments indicates whether there are more fragments following this frame func hasMoreFragments(f *Frame) bool { return f.Payload[_flagsIndex]&hasMoreFragmentsFlag != 0 } ================================================ FILE: relay_messages_benchmark_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"fmt"
	"io/ioutil"
	"testing"
)

// BenchmarkCallReqFrame measures the cost of wrapping a callReq frame with
// newLazyCallReq and then reading its service, caller and method three times
// each, which stands in for the repeated accessor calls a single relayed
// request goes through.
func BenchmarkCallReqFrame(b *testing.B) {
	f := reqHasAll.req(b).Frame

	var service, caller, method []byte

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		cr, err := newLazyCallReq(f)
		if err != nil {
			// Include the error so a failing benchmark is diagnosable.
			b.Fatalf("Unexpected error: %v", err)
		}

		// Multiple calls due to peer selection, stats, etc.
		for j := 0; j < 3; j++ {
			service = cr.Service()
			caller = cr.Caller()
			method = cr.Method()
		}
	}
	b.StopTimer()

	// Use the results so the accessor calls above cannot be optimized away.
	fmt.Fprint(ioutil.Discard, service, caller, method)
}

================================================
FILE: relay_messages_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package tchannel

import (
	"fmt"
	"io"
	"math"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/uber/tchannel-go/testutils/thriftarg2test"
	"github.com/uber/tchannel-go/thrift/arg2"
	"github.com/uber/tchannel-go/typed"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testCallReq is a bit set selecting which optional parts (transport headers
// and checksum) a generated callReq test frame contains.
type testCallReq int

const (
	reqHasHeaders testCallReq = (1 << iota)
	reqHasCaller
	reqHasDelegate
	reqHasRoutingKey
	reqHasChecksum

	reqTotalCombinations
	reqHasAll testCallReq = reqTotalCombinations - 1
)

var (
	exampleArg2Map = map[string]string{
		"foo": "bar",
		"baz": "qux",
	}
)

const (
	exampleService  = "fooservice"
	exampleArg3Data = "some arg3 data"
)

// testCallReqParams carries the per-test knobs used by frameWithParams when
// building a callReq frame.
type testCallReqParams struct {
	flags           byte
	hasTChanThrift  bool
	argScheme       Format
	arg2Buf         []byte
	overrideArg2Len int // when > 0, written as the arg2 length instead of len(arg2Buf)
	skipArg3        bool
	arg3Buf         []byte
	serviceOverride string // when empty, exampleService is used
}

// req builds a callReq frame with default parameters and wraps it.
func (cr testCallReq) req(tb testing.TB) lazyCallReq {
	return cr.reqWithParams(tb, testCallReqParams{})
}

// reqWithParams builds a callReq frame from p and wraps it, failing the test
// if the frame cannot be parsed.
func (cr testCallReq) reqWithParams(tb testing.TB, p testCallReqParams) lazyCallReq {
	lcr, err := newLazyCallReq(cr.frameWithParams(tb, p))
	require.NoError(tb, err)
	return *lcr
}

// frameWithParams writes a callReq frame whose optional headers and checksum
// are selected by cr and whose remaining contents come from p.
func (cr testCallReq) frameWithParams(t testing.TB, p testCallReqParams) *Frame {
	// TODO: Constructing a frame is ugly because the initial flags byte is
	// written in reqResWriter instead of callReq. We should instead handle that
	// in callReq, which will allow our tests to be sane.
	f := NewFrame(MaxFramePayloadSize)
	fh := fakeHeader(messageTypeCallReq)

	// Set the size in the header and write out the header after we know the payload contents.
	defer func() {
		fh.size = FrameHeaderSize + uint16(len(f.Payload))
		f.Header = fh
		require.NoError(t, fh.write(typed.NewWriteBuffer(f.headerBuffer)), "failed to write header")
	}()

	payload := typed.NewWriteBuffer(f.Payload)
	payload.WriteSingleByte(p.flags)     // flags
	payload.WriteUint32(42)              // TTL
	payload.WriteBytes(make([]byte, 25)) // tracing

	svc := p.serviceOverride
	if svc == "" {
		svc = exampleService
	}
	payload.WriteLen8String(svc) // service

	headers := make(map[string]string)
	switch p.argScheme {
	case HTTP, JSON, Raw, Thrift:
		headers["as"] = p.argScheme.String()
	}

	if cr&reqHasHeaders != 0 {
		addRandomHeaders(headers)
	}
	if cr&reqHasCaller != 0 {
		headers["cn"] = "fake-caller"
	}
	if cr&reqHasDelegate != 0 {
		headers["rd"] = "fake-delegate"
	}
	if cr&reqHasRoutingKey != 0 {
		headers["rk"] = "fake-routingkey"
	}
	writeHeaders(payload, headers)

	if cr&reqHasChecksum == 0 {
		payload.WriteSingleByte(byte(ChecksumTypeNone)) // checksum type
		// no checksum contents for None
	} else {
		payload.WriteSingleByte(byte(ChecksumTypeCrc32C)) // checksum type
		payload.WriteUint32(0)                            // checksum contents
	}
	payload.WriteLen16String("moneys") // method

	arg2Len := len(p.arg2Buf)
	if p.overrideArg2Len > 0 {
		arg2Len = p.overrideArg2Len
	}
	payload.WriteUint16(uint16(arg2Len))
	payload.WriteBytes(p.arg2Buf)

	if !p.skipArg3 {
		arg3Len := len(p.arg3Buf)
		payload.WriteUint16(uint16(arg3Len))
		payload.WriteBytes(p.arg3Buf)
	}

	f.Payload = f.Payload[:payload.BytesWritten()]
	require.NoError(t, payload.Err(), "failed to write payload")
	return f
}

// withLazyCallReqCombinations calls f once for every combination of the
// testCallReq flags.
func withLazyCallReqCombinations(f func(cr testCallReq)) {
	for cr := testCallReq(0); cr < reqTotalCombinations; cr++ {
		f(cr)
	}
}

// testCallRes is a bit set selecting the shape of a generated callRes frame.
type testCallRes int

// testCallResParams carries the contents written by newCallResFrame.
type testCallResParams struct {
	hasFragmentedArg2 bool
	flags             byte
	code              byte
	span              [25]byte
	isThrift          bool
	headers           map[string]string
	csumType          byte
	arg1              []byte
	arg2Prefix        []byte // used for corrupting arg2
	arg2KeyVals       map[string]string
	arg3              []byte
}

const (
	resIsContinued testCallRes = (1 << iota)
	resIsOK
	resHasHeaders
	resHasChecksum
	resIsThrift
	resHasArg2
	resHasFragmentedArg2

	resTotalCombinations
)

// res builds a callRes frame matching the flags in cr and wraps it with
// newLazyCallRes, failing the test if wrapping returns an error.
func (cr testCallRes) res(tb testing.TB) lazyCallRes {
	var params testCallResParams

	if cr&resHasFragmentedArg2 != 0 {
		params.hasFragmentedArg2 = true
	}
	if cr&(resIsContinued|resHasFragmentedArg2) != 0 {
		params.flags |= hasMoreFragmentsFlag
	}
	if cr&resIsOK == 0 {
		params.code = 1
	}

	params.headers = map[string]string{}
	if cr&resHasHeaders != 0 {
		params.headers["k1"] = "v1"
		params.headers["k222222"] = ""
		params.headers["k3"] = "thisisalonglongkey"
	}
	if cr&(resIsThrift) != 0 {
		params.isThrift = true
		params.headers[string(_argSchemeKeyBytes)] = string(_tchanThriftValueBytes)
	}
	if cr&resHasChecksum != 0 {
		params.csumType = byte(ChecksumTypeCrc32C)
	}
	if cr&resHasArg2 != 0 {
		params.arg2KeyVals = exampleArg2Map
	}

	lcr, err := newLazyCallRes(newCallResFrame(tb, params))
	require.NoError(tb, err, "Unexpected error creating lazyCallRes")
	return lcr
}

// withLazyCallResCombinations runs f as a subtest for every combination of
// the testCallRes flags; the subtest name is the flag set in binary.
func withLazyCallResCombinations(t *testing.T, f func(t *testing.T, cr testCallRes)) {
	for cr := testCallRes(0); cr < resTotalCombinations; cr++ {
		t.Run(fmt.Sprintf("cr=%v", strconv.FormatInt(int64(cr), 2)), func(t *testing.T) {
			f(t, cr)
		})
	}
}

// newCallResFrame writes a callRes frame described by p. When
// p.hasFragmentedArg2 is set, arg2 is padded to fill the rest of the frame
// and no arg3 is written.
func newCallResFrame(tb testing.TB, p testCallResParams) *Frame {
	f := NewFrame(MaxFramePayloadSize)
	fh := fakeHeader(messageTypeCallRes)
	payload := typed.NewWriteBuffer(f.Payload)

	// The header can only be written once the payload size is known.
	defer func() {
		fh.SetPayloadSize(uint16(payload.BytesWritten()))
		f.Header = fh
		require.NoError(tb, fh.write(typed.NewWriteBuffer(f.headerBuffer)), "Failed to write header")
	}()

	payload.WriteSingleByte(p.flags)              // flags
	payload.WriteSingleByte(p.code)               // code
	payload.WriteBytes(p.span[:])                 // span
	payload.WriteSingleByte(byte(len(p.headers))) // headers
	for k, v := range p.headers {
		payload.WriteSingleByte(byte(len(k)))
		payload.WriteBytes([]byte(k))
		payload.WriteSingleByte(byte(len(v)))
		payload.WriteBytes([]byte(v))
	}
	payload.WriteSingleByte(p.csumType)                                       // checksum type
	payload.WriteBytes(make([]byte, ChecksumType(p.csumType).ChecksumSize())) // dummy checksum (not used in tests)

	// arg1
	payload.WriteUint16(uint16(len(p.arg1)))
	payload.WriteBytes(p.arg1)
	require.NoError(tb, payload.Err(), "Got unexpected error constructing callRes frame")

	// arg2
	payload.WriteBytes(p.arg2Prefix) // prefix is used only for corrupting arg2
	arg2SizeRef := payload.DeferUint16()
	arg2StartBytes := payload.BytesWritten()
	if p.isThrift {
		arg2NHRef := payload.DeferUint16()
		var arg2NH uint16
		for k, v := range p.arg2KeyVals {
			arg2NH++
			payload.WriteLen16String(k)
			payload.WriteLen16String(v)
		}
		if p.hasFragmentedArg2 {
			// fill remainder of frame with the next key/val
			arg2NH++
			payload.WriteLen16String("ube")
			payload.WriteLen16String(strings.Repeat("r", payload.BytesRemaining()-2))
		}
		arg2NHRef.Update(arg2NH)
	} else {
		for k, v := range p.arg2KeyVals {
			payload.WriteString(k + v)
		}
		if p.hasFragmentedArg2 {
			payload.WriteString("ube" + strings.Repeat("r", payload.BytesRemaining()-3))
		}
	}
	require.NoError(tb, payload.Err(), "Got unexpected error constructing callRes frame")
	arg2SizeRef.Update(uint16(payload.BytesWritten() - arg2StartBytes))

	if !p.hasFragmentedArg2 {
		// arg3
		payload.WriteUint16(uint16(len(p.arg3)))
		payload.WriteBytes(p.arg3)
	}
	require.NoError(tb, payload.Err(), "Got unexpected error constructing callRes frame")
	return f
}

// fakeErrFrame builds an error frame carrying ec as its code and ec.String()
// as its message, and wraps it with newLazyError.
func (ec SystemErrCode) fakeErrFrame() lazyError {
	f := NewFrame(100)
	fh := FrameHeader{
		size:        uint16(0xFF34),
		messageType: messageTypeError,
		ID:          invalidMessageID,
	}
	f.Header = fh
	fh.write(typed.NewWriteBuffer(f.headerBuffer))

	payload := typed.NewWriteBuffer(f.Payload)
	payload.WriteSingleByte(byte(ec))
	payload.WriteBytes(make([]byte, 25)) // tracing

	msg := ec.String()
	payload.WriteUint16(uint16(len(msg)))
	payload.WriteBytes([]byte(msg))
	return newLazyError(f)
}

// withLazyErrorCombinations calls f once for each listed system error code.
func withLazyErrorCombinations(f func(ec SystemErrCode)) {
	codes := []SystemErrCode{
		ErrCodeInvalid,
		ErrCodeTimeout,
		ErrCodeCancelled,
		ErrCodeBusy,
		ErrCodeDeclined,
		ErrCodeUnexpected,
		ErrCodeBadRequest,
		ErrCodeNetwork,
		ErrCodeProtocol,
	}
	for _, ec := range codes {
		f(ec)
	}
}

// addRandomHeaders adds a fixed set of extra headers (short, empty-valued and
// long-valued) that the parser is expected to skip over.
func addRandomHeaders(headers map[string]string) {
	headers["k1"] = "v1"
	headers["k222222"] = ""
	headers["k3"] = "thisisalonglongkey"
}

// writeHeaders writes headers in the nh:1 (hk~1 hv~1){nh} layout.
func writeHeaders(w *typed.WriteBuffer, headers map[string]string) {
	w.WriteSingleByte(byte(len(headers))) // number of headers
	for k, v := range headers {
		w.WriteLen8String(k)
		w.WriteLen8String(v)
	}
}

// assertWrappingPanics asserts that wrap panics when applied to f.
func assertWrappingPanics(t testing.TB, f *Frame, wrap func(f *Frame)) {
	assert.Panics(t, func() {
		wrap(f)
	}, "Should panic when wrapping an unexpected frame type.")
}

func TestLazyCallReqRejectsOtherFrames(t *testing.T) {
	assertWrappingPanics(
		t,
		resIsContinued.res(t).Frame,
		func(f *Frame) { newLazyCallReq(f) },
	)
}

func TestLazyCallReqService(t *testing.T) {
	withLazyCallReqCombinations(func(crt testCallReq) {
		cr := crt.req(t)
		assert.Equal(t, exampleService, string(cr.Service()), "Service name mismatch")
	})
}

func TestLazyCallReqCaller(t *testing.T) {
	withLazyCallReqCombinations(func(crt testCallReq) {
		cr := crt.req(t)
		if crt&reqHasCaller == 0 {
			assert.Equal(t, []byte(nil), cr.Caller(), "Unexpected caller name.")
		} else {
			assert.Equal(t, "fake-caller", string(cr.Caller()), "Caller name mismatch")
		}
	})
}

func TestLazyCallReqRoutingDelegate(t *testing.T) {
	withLazyCallReqCombinations(func(crt testCallReq) {
		cr := crt.req(t)
		if crt&reqHasDelegate == 0 {
			assert.Equal(t, []byte(nil), cr.RoutingDelegate(), "Unexpected routing delegate.")
		} else {
			assert.Equal(t, "fake-delegate", string(cr.RoutingDelegate()), "Routing delegate mismatch.")
		}
	})
}

func TestLazyCallReqRoutingKey(t *testing.T) {
	withLazyCallReqCombinations(func(crt testCallReq) {
		cr := crt.req(t)
		if crt&reqHasRoutingKey == 0 {
			assert.Equal(t, []byte(nil), cr.RoutingKey(), "Unexpected routing key.")
		} else {
			assert.Equal(t, "fake-routingkey", string(cr.RoutingKey()), "Routing key mismatch.")
		}
	})
}

func TestLazyCallReqMethod(t *testing.T) {
	withLazyCallReqCombinations(func(crt testCallReq) {
		cr := crt.req(t)
		assert.Equal(t, "moneys", string(cr.Method()), "Method name mismatch")
	})
}

func TestLazyCallReqTTL(t *testing.T) {
	withLazyCallReqCombinations(func(crt testCallReq) {
		cr := crt.req(t)
		assert.Equal(t, 42*time.Millisecond, cr.TTL(), "Failed to parse TTL from frame.")
	})
}

func TestLazyCallReqSetTTL(t *testing.T) {
	withLazyCallReqCombinations(func(crt testCallReq) {
		cr := crt.req(t)
		cr.SetTTL(time.Second)
		assert.Equal(t, time.Second, cr.TTL(), "Failed to write TTL to frame.")
	})
}

func TestLazyCallArg2Offset(t *testing.T) {
	wantArg2Buf := []byte("test arg2 buf")
	tests := []struct {
		msg     string
		flags   byte
		arg2Buf []byte
	}{
		{
			msg:     "arg2 is fully contained in frame",
			arg2Buf: wantArg2Buf,
		},
		{
			msg: "has no arg2",
		},
		{
			msg:     "frame fragmented but arg2 is fully contained",
			flags:   hasMoreFragmentsFlag,
			arg2Buf: wantArg2Buf,
		},
	}
	for _, tt := range tests {
		t.Run(tt.msg, func(t *testing.T) {
			withLazyCallReqCombinations(func(crt testCallReq) {
				cr := crt.reqWithParams(t, testCallReqParams{
					flags:   tt.flags,
					arg2Buf: tt.arg2Buf,
				})
				arg2EndOffset, hasMore := cr.Arg2EndOffset()
				assert.False(t, hasMore)
				if len(tt.arg2Buf) == 0 {
					assert.Zero(t, arg2EndOffset-cr.Arg2StartOffset())
					return
				}

				arg2Payload := cr.Payload[cr.Arg2StartOffset():arg2EndOffset]
				assert.Equal(t, tt.arg2Buf, arg2Payload)
			})
		})
	}

	t.Run("no arg3 set", func(t *testing.T) {
		tests := []struct {
			msg       string
			hasMore   bool
			wantError string
		}{
			{
				msg:     "hasMore flag=true",
				hasMore: true,
			},
			{
				msg:       "hasMore flag=false",
				wantError: "buffer is too small",
			},
		}
		for _, tt := range tests {
			// Pass the name directly: it is not a format string.
			t.Run(tt.msg, func(t *testing.T) {
				withLazyCallReqCombinations(func(crt testCallReq) {
					// For each CallReq, we first get the remaining space left, and
					// fill up the remaining space with arg2.
					crNoArg2 := crt.req(t)
					arg2Size := MaxFramePayloadSize - crNoArg2.Arg2StartOffset()
					var flags byte
					if tt.hasMore {
						flags |= hasMoreFragmentsFlag
					}
					f := crt.frameWithParams(t, testCallReqParams{
						flags:    flags,
						arg2Buf:  make([]byte, arg2Size),
						skipArg3: true,
					})
					cr, err := newLazyCallReq(f)
					if tt.wantError != "" {
						require.EqualError(t, err, tt.wantError)
						return
					}
					require.NoError(t, err)
					endOffset, hasMore := cr.Arg2EndOffset()
					assert.Equal(t, tt.hasMore, hasMore)
					assert.EqualValues(t, MaxFramePayloadSize, endOffset)
				})
			})
		}
	})
}

func TestLazyCallReqSetTChanThriftArg2(t *testing.T) {
	tests := []struct {
		msg            string
		bufKV          map[string]string
		wantKV         map[string]string // if not set, use bufKV
		argScheme      Format
		overrideBufLen int
		wantBadErr     string
	}{
		{
			msg:       "two key value pairs",
			argScheme: Thrift,
			bufKV: map[string]string{
				"key":  "val",
				"key2": "val2",
			},
		},
		{
			msg:       "length not enough to cover key len",
			argScheme: Thrift,
			bufKV: map[string]string{
				"key": "val",
			},
			overrideBufLen: 3, // 2 (nh) + 2 - 1
			wantBadErr:     "buffer is too small",
		},
		{
			msg:       "length not enough to cover key",
			argScheme: Thrift,
			bufKV: map[string]string{
				"key": "val",
			},
			overrideBufLen: 6, // 2 (nh) + 2 + len(key) - 1
			wantBadErr:     "buffer is too small",
		},
		{
			msg:       "length not enough to cover value len",
			argScheme: Thrift,
			bufKV: map[string]string{
				"key": "val",
			},
			overrideBufLen: 8, // 2 (nh) + 2 + len(key) + 2 - 1
			wantBadErr:     "buffer is too small",
		},
		{
			msg:       "length not enough to cover value",
			argScheme: Thrift,
			bufKV: map[string]string{
				"key": "val",
			},
			overrideBufLen: 10, // 2 (nh) + 2 + len(key) + 2 + len(val) - 2
			wantBadErr:     "buffer is too small",
		},
		{
			msg:       "no key value pairs",
			argScheme: Thrift,
			bufKV:     map[string]string{},
		},
		{
			msg:        "not tchannel thrift",
			argScheme:  HTTP,
			bufKV:      map[string]string{"key": "val"},
			wantBadErr: "non-Thrift",
		},
		{
			msg:        "not arg scheme",
			bufKV:      map[string]string{"key": "val"},
			wantBadErr: "non-Thrift",
		},
	}

	for _, tt := range tests {
		t.Run(tt.msg, func(t *testing.T) {
			withLazyCallReqCombinations(func(crt testCallReq) {
				arg2Buf := thriftarg2test.BuildKVBuffer(tt.bufKV)
				if tt.overrideBufLen > 0 {
					arg2Buf = arg2Buf[:tt.overrideBufLen]
				}
				cr := crt.reqWithParams(t, testCallReqParams{
					arg2Buf:   arg2Buf,
					argScheme: tt.argScheme,
				})
				gotIter := make(map[string]string)
				iter, err := cr.Arg2Iterator()
				for err == nil {
					gotIter[string(iter.Key())] = string(iter.Value())
					iter, err = iter.Next()
				}
				if tt.wantBadErr != "" {
					require.NotEqual(t, io.EOF, err, "should not return EOF for iterator exit")
					assert.Contains(t, err.Error(), tt.wantBadErr)
				} else {
					assert.Equal(t, io.EOF, err, "should return EOF for iterator exit")
					wantKV := tt.wantKV
					if wantKV == nil {
						wantKV = tt.bufKV
					}
					assert.Equal(t, wantKV, gotIter, "unexpected arg2 keys, call req %+v", crt)
				}
			})
		})
	}

	t.Run("bad Arg2 length", func(t *testing.T) {
		withLazyCallReqCombinations(func(crt testCallReq) {
			crNoArg2 := crt.req(t)
			leftSpace := int(crNoArg2.Header.PayloadSize()) - crNoArg2.Arg2StartOffset()
			frm := crt.frameWithParams(t, testCallReqParams{
				arg2Buf:         make([]byte, leftSpace),
				argScheme:       Thrift,
				overrideArg2Len: leftSpace + 5, // Arg2 length extends beyond payload
			})
			_, err := newLazyCallReq(frm)
			assert.EqualError(t, err, "buffer is too small")
		})
	})
}

func TestLazyCallResRejectsOtherFrames(t *testing.T) {
	assertWrappingPanics(
		t,
		reqHasHeaders.req(t).Frame,
		func(f *Frame) { newLazyCallRes(f) },
	)
}

func TestLazyCallRes(t *testing.T) {
	withLazyCallResCombinations(t, func(t *testing.T, crt testCallRes) {
		cr := crt.res(t)

		// isOK
		if crt&resIsOK == 0 {
			assert.False(t, cr.OK(), "Expected call res to have a non-ok code.")
		} else {
			assert.True(t, cr.OK(), "Expected call res to have code ok.")
		}

		// isThrift
		if crt&resIsThrift != 0 {
			assert.Equal(t, Thrift.String(), string(cr.ArgScheme()), "Expected call res to have isThrift=true")
			assert.Equal(t, _tchanThriftValueBytes, cr.as, "Expected arg scheme to be thrift")
		} else {
			assert.NotEqual(t, Thrift.String(), string(cr.ArgScheme()), "Expected call res to have isThrift=false")
			assert.NotEqual(t, _tchanThriftValueBytes, cr.as, "Expected arg scheme to not be thrift")
		}

		// arg2IsFragmented
		if crt&resHasFragmentedArg2 != 0 {
			assert.True(t, cr.Arg2IsFragmented(), "Expected arg2 to be fragmented")
		}

		if crt&resIsThrift != 0 {
			iter, err := arg2.NewKeyValIterator(cr.Arg2())
			if crt&resHasArg2 != 0 || crt&resHasFragmentedArg2 != 0 {
				require.NoError(t, err, "Got unexpected error for .Arg2()")
				kvMap := make(map[string]string)
				for ; err == nil; iter, err = iter.Next() {
					kvMap[string(iter.Key())] = string(iter.Value())
				}
				if crt&resHasArg2 != 0 {
					for k, v := range exampleArg2Map {
						assert.Equal(t, v, kvMap[k])
					}
				}
			} else {
				// msgAndArgs must start with the message string; passing an
				// error value first breaks the failure output.
				require.Error(t, err, "Expected an error from .Arg2() iteration when there are no key/value pairs")
			}
		}
	})
}

func TestNewLazyCallResCorruptedFrame(t *testing.T) {
	_, err := newLazyCallRes(newCallResFrame(t, testCallResParams{
		arg2Prefix:  []byte{0, 100},
		arg2KeyVals: exampleArg2Map,
	}))

	require.EqualError(t, err, "read response frame: buffer is too small", "Got unexpected error for corrupted frame")
}

func TestLazyErrorRejectsOtherFrames(t *testing.T) {
	assertWrappingPanics(
		t,
		reqHasHeaders.req(t).Frame,
		func(f *Frame) { newLazyError(f) },
	)
}

func TestLazyErrorCodes(t *testing.T) {
	withLazyErrorCombinations(func(ec SystemErrCode) {
		f := ec.fakeErrFrame()
		assert.Equal(t, ec, f.Code(), "Mismatch between error code and lazy frame's Code() method.")
	})
}

// uint16KeyValToMap decodes a buffer laid out as nh:2 (k~2 v~2){nh} into a map.
//
// TODO(cinchurge): replace with e.g. decodeThriftHeader once we've resolved the import cycle
func uint16KeyValToMap(tb testing.TB, buffer []byte) map[string]string {
	rbuf := typed.NewReadBuffer(buffer)
	nh := int(rbuf.ReadUint16())
	retMap := make(map[string]string, nh)
	for i := 0; i < nh; i++ {
		keyLen := int(rbuf.ReadUint16())
		key := rbuf.ReadBytes(keyLen)
		valLen := int(rbuf.ReadUint16())
		val := rbuf.ReadBytes(valLen)
		retMap[string(key)] = string(val)
	}
	require.NoError(tb, rbuf.Err())
	return retMap
}

func TestLazyCallReqContents(t *testing.T) {
	cr := reqHasAll.reqWithParams(t, testCallReqParams{
		arg2Buf: thriftarg2test.BuildKVBuffer(exampleArg2Map),
		arg3Buf: []byte(exampleArg3Data),
	})

	t.Run("checksum", func(t *testing.T) {
		assert.Equal(t, ChecksumTypeCrc32C, cr.checksumType, "Got unexpected checksum type")
		assert.Equal(t, byte(ChecksumTypeCrc32C), cr.Frame.Payload[cr.checksumTypeOffset], "Unexpected value read from checksum offset")
	})

	t.Run(".arg2()", func(t *testing.T) {
		assert.Equal(t, exampleArg2Map, uint16KeyValToMap(t, cr.arg2()), "Got unexpected headers")
	})

	t.Run(".arg3()", func(t *testing.T) {
		// TODO(echung): switch to assert.Equal once we have more robust test frame generation
		assert.Contains(t, string(cr.arg3()), exampleArg3Data, "Got unexpected headers")
	})
}

func TestLazyCallReqLargeService(t *testing.T) {
	// The service name is written with a one-byte length prefix, so
	// math.MaxUint8 is the largest size that can be encoded.
	for _, svcSize := range []int{10, 100, 200, 240, math.MaxUint8} {
		t.Run(fmt.Sprintf("size=%v", svcSize), func(t *testing.T) {
			largeService := strings.Repeat("a", svcSize)
			withLazyCallReqCombinations(func(cr testCallReq) {
				f := cr.frameWithParams(t, testCallReqParams{
					serviceOverride: largeService,
				})

				callReq, err := newLazyCallReq(f)
				require.NoError(t, err, "newLazyCallReq failed")

				assert.Equal(t, largeService, string(callReq.Service()), "service name mismatch")
			})
		})
	}
}

================================================
FILE: relay_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "bytes" "errors" "fmt" "io" "io/ioutil" "net" "os" "runtime" "strings" "sync" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/benchmark" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/relay" "github.com/uber/tchannel-go/relay/relaytest" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/testutils/testreader" "github.com/uber/tchannel-go/testutils/thriftarg2test" "github.com/uber/tchannel-go/thrift" "github.com/uber/tchannel-go/thrift/arg2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" "golang.org/x/net/context" ) type relayTest struct { testutils.TestServer } func serviceNameOpts(s string) *testutils.ChannelOpts { return testutils.NewOpts().SetServiceName(s) } func withRelayedEcho(t testing.TB, f func(relay, server, client *Channel, ts *testutils.TestServer)) { opts := serviceNameOpts("test"). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { testutils.RegisterEcho(ts.Server(), nil) client := ts.NewClient(serviceNameOpts("client")) client.Peers().Add(ts.HostPort()) f(ts.Relay(), ts.Server(), client, ts) }) } func TestRelay(t *testing.T) { withRelayedEcho(t, func(_, _, client *Channel, ts *testutils.TestServer) { tests := []struct { header string body string }{ {"fake-header", "fake-body"}, // fits in one frame {"fake-header", strings.Repeat("fake-body", 10000)}, // requires continuation } sc := client.GetSubChannel("test") for _, tt := range tests { ctx, cancel := NewContext(time.Second) defer cancel() arg2, arg3, _, err := raw.CallSC(ctx, sc, "echo", []byte(tt.header), []byte(tt.body)) require.NoError(t, err, "Relayed call failed.") assert.Equal(t, tt.header, string(arg2), "Header was mangled during relay.") assert.Equal(t, tt.body, string(arg3), "Body was mangled during relay.") } calls := relaytest.NewMockStats() for range tests { calls.Add("client", "test", "echo").Succeeded().End() } ts.AssertRelayStats(calls) }) } func TestRelaySetHost(t *testing.T) { rh := 
relaytest.NewStubRelayHost() opts := serviceNameOpts("test"). SetRelayHost(rh). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { testutils.RegisterEcho(ts.Server(), nil) client := ts.NewClient(serviceNameOpts("client")) client.Peers().Add(ts.HostPort()) testutils.AssertEcho(t, client, ts.HostPort(), ts.Server().ServiceName()) }) } func TestRelayHandlesClosedPeers(t *testing.T) { opts := serviceNameOpts("test"). SetRelayOnly(). SetCheckFramePooling(). // Disable logs as we are closing connections that can error in a lot of places. DisableLogVerification() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(300 * time.Millisecond) defer cancel() testutils.RegisterEcho(ts.Server(), nil) client := ts.NewClient(serviceNameOpts("client")) client.Peers().Add(ts.HostPort()) sc := client.GetSubChannel("test") _, _, _, err := raw.CallSC(ctx, sc, "echo", []byte("fake-header"), []byte("fake-body")) require.NoError(t, err, "Relayed call failed.") ts.Server().Close() require.NotPanics(t, func() { raw.CallSC(ctx, sc, "echo", []byte("fake-header"), []byte("fake-body")) }) }) } func TestRelayConnectionCloseDrainsRelayItems(t *testing.T) { opts := serviceNameOpts("s1"). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ctx, cancel := NewContext(time.Second) defer cancel() s1 := ts.Server() s2 := ts.NewServer(serviceNameOpts("s2")) s2HP := s2.PeerInfo().HostPort testutils.RegisterEcho(s1, func() { // When s1 gets called, it calls Close on the connection from the relay to s2. 
conn, err := ts.Relay().Peers().GetOrAdd(s2HP).GetConnection(ctx) require.NoError(t, err, "Unexpected failure getting connection between s1 and relay") conn.Close() }) testutils.AssertEcho(t, s2, ts.HostPort(), "s1") calls := relaytest.NewMockStats() calls.Add("s2", "s1", "echo").Succeeded().End() ts.AssertRelayStats(calls) }) } func TestRelayIDClash(t *testing.T) { // TODO: enable framepool checks opts := serviceNameOpts("s1"). SetRelayOnly() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { s1 := ts.Server() s2 := ts.NewServer(serviceNameOpts("s2")) unblock := make(chan struct{}) testutils.RegisterEcho(s1, func() { <-unblock }) testutils.RegisterEcho(s2, nil) var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() testutils.AssertEcho(t, s2, ts.HostPort(), s1.ServiceName()) }() } for i := 0; i < 5; i++ { testutils.AssertEcho(t, s1, ts.HostPort(), s2.ServiceName()) } close(unblock) wg.Wait() }) } func TestRelayErrorsOnGetPeer(t *testing.T) { busyErr := NewSystemError(ErrCodeBusy, "busy") tests := []struct { desc string returnPeer string returnErr error statsKey string wantErr error }{ { desc: "No peer and no error", returnPeer: "", returnErr: nil, statsKey: "relay-bad-relay-host", wantErr: NewSystemError(ErrCodeDeclined, `bad relay host implementation`), }, { desc: "System error getting peer", returnErr: busyErr, statsKey: "relay-busy", wantErr: busyErr, }, { desc: "Unknown error getting peer", returnErr: errors.New("unknown"), statsKey: "relay-declined", wantErr: NewSystemError(ErrCodeDeclined, "unknown"), }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { f := func(relay.CallFrame, *relay.Conn) (string, error) { return tt.returnPeer, tt.returnErr } opts := testutils.NewOpts(). SetRelayHost(relaytest.HostFunc(f)). SetRelayOnly(). SetCheckFramePooling(). DisableLogVerification() // some of the test cases cause warnings. 
testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { client := ts.NewClient(nil) err := testutils.CallEcho(client, ts.HostPort(), "svc", nil) require.Error(t, err, "Call to unknown service should fail") assert.Equal(t, tt.wantErr, err, "unexpected error") calls := relaytest.NewMockStats() calls.Add(client.PeerInfo().ServiceName, "svc", "echo"). Failed(tt.statsKey).End() ts.AssertRelayStats(calls) }) }) } } func TestErrorFrameEndsRelay(t *testing.T) { // TestServer validates that there are no relay items left after the given func. opts := serviceNameOpts("svc"). SetRelayOnly(). SetCheckFramePooling(). DisableLogVerification() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { client := ts.NewClient(nil) err := testutils.CallEcho(client, ts.HostPort(), "svc", nil) if !assert.Error(t, err, "Expected error due to unknown method") { return } se, ok := err.(SystemError) if !assert.True(t, ok, "err should be a SystemError, got %T", err) { return } assert.Equal(t, ErrCodeBadRequest, se.Code(), "Expected BadRequest error") calls := relaytest.NewMockStats() calls.Add(client.PeerInfo().ServiceName, "svc", "echo").Failed("bad-request").End() ts.AssertRelayStats(calls) }) } // Trigger a race between receiving a new call and a connection closing // by closing the relay while a lot of background calls are being made. func TestRaceCloseWithNewCall(t *testing.T) { opts := serviceNameOpts("s1"). SetRelayOnly(). SetCheckFramePooling(). DisableLogVerification() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { s1 := ts.Server() s2 := ts.NewServer(serviceNameOpts("s2").DisableLogVerification()) testutils.RegisterEcho(s1, nil) // signal to start closing the relay. 
var ( closeRelay sync.WaitGroup stopCalling atomic.Int32 callers sync.WaitGroup ) for i := 0; i < 5; i++ { callers.Add(1) closeRelay.Add(1) go func() { defer callers.Done() calls := 0 for stopCalling.Load() == 0 { testutils.CallEcho(s2, ts.HostPort(), "s1", nil) calls++ if calls == 5 { closeRelay.Done() } } }() } closeRelay.Wait() // Close the relay, wait for it to close. ts.Relay().Close() closed := testutils.WaitFor(time.Second, func() bool { return ts.Relay().State() == ChannelClosed }) assert.True(t, closed, "Relay did not close within timeout") // Now stop all calls, and wait for the calling goroutine to end. stopCalling.Inc() callers.Wait() }) } func TestTimeoutCallsThenClose(t *testing.T) { // TODO: enable framepool checks // Test needs at least 2 CPUs to trigger race conditions. defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(2)) opts := serviceNameOpts("s1"). SetRelayOnly(). DisableLogVerification() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { s1 := ts.Server() s2 := ts.NewServer(serviceNameOpts("s2").DisableLogVerification()) unblockEcho := make(chan struct{}) testutils.RegisterEcho(s1, func() { <-unblockEcho }) ctx, cancel := NewContext(testutils.Timeout(100 * time.Millisecond)) defer cancel() var callers sync.WaitGroup for i := 0; i < 100; i++ { callers.Add(1) go func() { defer callers.Done() raw.Call(ctx, s2, ts.HostPort(), "s1", "echo", nil, nil) }() } close(unblockEcho) // Wait for all the callers to end callers.Wait() }) } func TestLargeTimeoutsAreClamped(t *testing.T) { const ( clampTTL = time.Millisecond longTTL = time.Minute ) opts := serviceNameOpts("echo-service"). SetRelayOnly(). SetCheckFramePooling(). SetRelayMaxTimeout(clampTTL). 
DisableLogVerification() // handler returns after deadline
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		srv := ts.Server()
		client := ts.NewClient(nil)

		unblock := make(chan struct{})
		defer close(unblock) // let server shut down cleanly
		testutils.RegisterFunc(srv, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) {
			now := time.Now()
			deadline, ok := ctx.Deadline()
			assert.True(t, ok, "Expected deadline to be set in handler.")
			assert.True(t, deadline.Sub(now) <= clampTTL, "Expected relay to clamp TTL sent to backend.")

			<-unblock
			return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, nil
		})

		done := make(chan struct{})
		go func() {
			// Always signal completion, even when an assertion below fails,
			// so the select at the end does not wait for longTTL/2.
			defer close(done)

			ctx, cancel := NewContext(longTTL)
			defer cancel()

			_, _, _, err := raw.Call(ctx, client, ts.HostPort(), "echo-service", "echo", nil, nil)
			// Use assert rather than require here: FailNow must only be
			// called from the goroutine running the test function.
			if !assert.Error(t, err) {
				return
			}
			code := GetSystemErrorCode(err)
			assert.Equal(t, ErrCodeTimeout, code)
		}()

		// This test is very sensitive to system noise, where a spike of latency in the relay (e.g. caused by load)
		// is able to cause the client call to timeout, making this test prone to false positives. As such, we
		// can't time out too close to clampTTL, but instead check that we don't time out after longTTL/2. This might
		// be a bit generous, but should be sufficient for our purposes here.
		select {
		case <-time.After(testutils.Timeout(longTTL / 2)):
			t.Fatal("Failed to clamp timeout.")
		case <-done:
		}
	})
}

// TestRelayConcurrentCalls makes many concurrent calls and ensures that
// we don't try to reuse any frames once they've been released.
func TestRelayConcurrentCalls(t *testing.T) {
	opts := testutils.NewOpts().
		SetRelayOnly().
		SetCheckFramePooling()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		server := benchmark.NewServer(
			benchmark.WithNoLibrary(),
			benchmark.WithServiceName("s1"),
		)
		defer server.Close()

		// Route calls for "s1" through the relay to the benchmark server.
		ts.RelayHost().Add("s1", server.HostPort())

		client := benchmark.NewClient([]string{ts.HostPort()},
			benchmark.WithNoDurations(),
			// TODO(prashant): Enable once we have control over concurrency with NoLibrary.
			// benchmark.WithNoLibrary(),
			benchmark.WithNumClients(20),
			benchmark.WithServiceName("s1"),
			benchmark.WithTimeout(time.Minute),
		)
		defer client.Close()
		require.NoError(t, client.Warmup(), "Client warmup failed")

		_, err := client.RawCall(1000)
		assert.NoError(t, err, "RawCalls failed")
	})
}

// Ensure that any connections created in the relay path send the ephemeral
// host:port.
func TestRelayOutgoingConnectionsEphemeral(t *testing.T) {
	opts := testutils.NewOpts().
		SetRelayOnly().
		SetCheckFramePooling()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		s2 := ts.NewServer(serviceNameOpts("s2"))
		testutils.RegisterFunc(s2, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) {
			// The handler inspects the peer as seen by s2 for the relayed call.
			assert.True(t, CurrentCall(ctx).RemotePeer().IsEphemeral,
				"Connections created for the relay should send ephemeral host:port header")

			return &raw.Res{
				Arg2: args.Arg2,
				Arg3: args.Arg3,
			}, nil
		})

		require.NoError(t, testutils.CallEcho(ts.Server(), ts.HostPort(), "s2", nil), "CallEcho failed")
	})
}

// TestRelayHandleLocalCall verifies that service names configured via
// SetRelayLocal are handled by the relay channel itself, while other
// services are relayed (or declined when unknown).
func TestRelayHandleLocalCall(t *testing.T) {
	opts := testutils.NewOpts().
		SetRelayOnly().
		SetCheckFramePooling().
		SetRelayLocal("relay", "tchannel", "test").
		// We make a call to "test" for an unknown method.
AddLogFilter("Couldn't find handler.", 1) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { s2 := ts.NewServer(serviceNameOpts("s2")) testutils.RegisterEcho(s2, nil) client := ts.NewClient(nil) testutils.AssertEcho(t, client, ts.HostPort(), "s2") testutils.RegisterEcho(ts.Relay(), nil) testutils.AssertEcho(t, client, ts.HostPort(), "relay") // Sould get a bad request for "test" since the channel does not handle it. err := testutils.CallEcho(client, ts.HostPort(), "test", nil) assert.Equal(t, ErrCodeBadRequest, GetSystemErrorCode(err), "Expected BadRequest for test") // But an unknown service causes declined err = testutils.CallEcho(client, ts.HostPort(), "unknown", nil) assert.Equal(t, ErrCodeDeclined, GetSystemErrorCode(err), "Expected Declined for unknown") calls := relaytest.NewMockStats() calls.Add(client.ServiceName(), "s2", "echo").Succeeded().End() calls.Add(client.ServiceName(), "unknown", "echo").Failed("relay-declined").End() ts.AssertRelayStats(calls) }) } func TestRelayHandleLargeLocalCall(t *testing.T) { // TODO: enablle framepool checks opts := testutils.NewOpts().SetRelayOnly(). SetRelayLocal("relay"). AddLogFilter("Received fragmented callReq", 1). // Expect 4 callReqContinues for 256 kb payload that we cannot relay. AddLogFilter("Failed to relay frame.", 4) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { client := ts.NewClient(nil) testutils.RegisterEcho(ts.Relay(), nil) // This large call should fail with a bad request. err := testutils.CallEcho(client, ts.HostPort(), "relay", &raw.Args{ Arg2: testutils.RandBytes(128 * 1024), Arg3: testutils.RandBytes(128 * 1024), }) if assert.Equal(t, ErrCodeBadRequest, GetSystemErrorCode(err), "Expected BadRequest for large call to relay") { assert.Contains(t, err.Error(), "cannot receive fragmented calls") } // We may get an error before the call is finished flushing. // Do a ping to ensure everything has been flushed. 
ctx, cancel := NewContext(time.Second) defer cancel() require.NoError(t, client.Ping(ctx, ts.HostPort()), "Ping failed") }) } func TestRelayMakeOutgoingCall(t *testing.T) { opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { svr1 := ts.Relay() svr2 := ts.NewServer(testutils.NewOpts().SetServiceName("svc2")) testutils.RegisterEcho(svr2, nil) sizes := []int{128, 1024, 128 * 1024} for _, size := range sizes { t.(*testing.T).Run(fmt.Sprintf("size=%d", size), func(t *testing.T) { err := testutils.CallEcho(svr1, ts.HostPort(), "svc2", &raw.Args{ Arg2: testutils.RandBytes(size), Arg3: testutils.RandBytes(size), }) assert.NoError(t, err, "Echo with size %v failed", size) }) } }) } func TestRelayInboundConnContext(t *testing.T) { rh := relaytest.NewStubRelayHost() rh.SetFrameFn(func(f relay.CallFrame, conn *relay.Conn) { // Verify that the relay gets the base context set in the server's ConnContext assert.Equal(t, "bar", conn.Context.Value("foo"), "Unexpected value set in base context") }) opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). SetRelayHost(rh). SetConnContext(func(ctx context.Context, conn net.Conn) context.Context { return context.WithValue(ctx, "foo", "bar") }) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { rly := ts.Relay() svr := ts.Server() testutils.RegisterEcho(svr, nil) client := testutils.NewClient(t, nil) testutils.AssertEcho(t, client, rly.PeerInfo().HostPort, ts.ServiceName()) }) } func TestRelayContextInheritsFromOutboundConnection(t *testing.T) { rh := relaytest.NewStubRelayHost() rh.SetFrameFn(func(f relay.CallFrame, conn *relay.Conn) { // Verify that the relay gets the base context set by the outbound connection to the caller assert.Equal(t, "bar", conn.Context.Value("foo"), "Unexpected value set in base context") }) opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). 
SetRelayHost(rh)
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		rly := ts.Relay()
		callee := ts.Server()
		testutils.RegisterEcho(callee, nil)

		caller := ts.NewServer(testutils.NewOpts())
		testutils.RegisterEcho(caller, nil)

		// Have the relay dial the caller first, attaching baseCtx to that
		// outbound connection via the context builder.
		baseCtx := context.WithValue(context.Background(), "foo", "bar")
		ctx, cancel := NewContextBuilder(time.Second).SetConnectBaseContext(baseCtx).Build()
		defer cancel()

		require.NoError(t, rly.Ping(ctx, caller.PeerInfo().HostPort))
		testutils.AssertEcho(t, caller, ts.HostPort(), ts.ServiceName())
	})
}

// TestRelayConnection verifies the relay.Conn passed to the relay host for
// three kinds of callers: a client, a listening server that dials the relay,
// and a listening server that the relay dialed first. The relay host always
// fails the lookup with errTest, so every call is expected to error; the test
// only inspects the captured connection details.
func TestRelayConnection(t *testing.T) {
	var errTest = errors.New("test")
	var gotConn *relay.Conn

	// getHost records the connection it was called with and rejects the call.
	getHost := func(_ relay.CallFrame, conn *relay.Conn) (string, error) {
		gotConn = conn
		return "", errTest
	}

	opts := testutils.NewOpts().
		SetRelayOnly().
		SetCheckFramePooling().
		SetRelayHost(relaytest.HostFunc(getHost))
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// getConn returns ch's single connection to the relay, taken from
		// either the outbound or inbound list of the relay's root peer.
		getConn := func(ch *Channel, outbound bool) ConnectionRuntimeState {
			state := ch.IntrospectState(nil)
			peer, ok := state.RootPeers[ts.HostPort()]
			require.True(t, ok, "Failed to find peer for relay")

			conns := peer.InboundConnections
			if outbound {
				conns = peer.OutboundConnections
			}

			require.Len(t, conns, 1, "Expect single connection from client to relay")
			return conns[0]
		}

		// Create a client that is listening so we can set the expected host:port.
		client := ts.NewClient(nil)

		err := testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil)
		require.Error(t, err, "Expected CallEcho to fail")
		assert.Contains(t, err.Error(), errTest.Error(), "Unexpected error")

		wantConn := &relay.Conn{
			RemoteAddr:        getConn(client, true /* outbound */).LocalHostPort,
			RemoteProcessName: client.PeerInfo().ProcessName,
			IsOutbound:        false,
			Context:           context.Background(),
		}
		assert.Equal(t, wantConn, gotConn, "Unexpected remote addr")

		// Verify something similar with a listening channel, ensuring that
		// we're not using the host:port of the listening server, but the
		// host:port of the outbound TCP connection.
		listeningC := ts.NewServer(nil)

		err = testutils.CallEcho(listeningC, ts.HostPort(), ts.ServiceName(), nil)
		require.Error(t, err, "Expected CallEcho to fail")
		assert.Contains(t, err.Error(), errTest.Error(), "Unexpected error")

		connHostPort := getConn(listeningC, true /* outbound */).LocalHostPort
		assert.NotEqual(t, connHostPort, listeningC.PeerInfo().HostPort, "Ensure connection host:port is not listening host:port")

		wantConn = &relay.Conn{
			RemoteAddr:        connHostPort,
			RemoteProcessName: listeningC.PeerInfo().ProcessName,
			Context:           context.Background(),
		}
		assert.Equal(t, wantConn, gotConn, "Unexpected remote addr")

		// Connections created when relaying hide the relay host:port to ensure
		// services don't send calls back over that same connection. However,
		// this is what happens in the hyperbahn emulation case, so create
		// an explicit connection to a new listening channel.
		listeningHBSvc := ts.NewServer(nil)

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()

		// Ping to ensure the connection has been added to peers on both sides.
		err = ts.Relay().Ping(ctx, listeningHBSvc.PeerInfo().HostPort)
		require.NoError(t, err, "Failed to connect from relay to listening host:port")

		// Now when listeningHBSvc makes a call, it should use the above connection.
		err = testutils.CallEcho(listeningHBSvc, ts.HostPort(), ts.ServiceName(), nil)
		require.Error(t, err, "Expected CallEcho to fail")
		assert.Contains(t, err.Error(), errTest.Error(), "Unexpected error")

		// We expect an inbound connection on listeningHBSvc.
		connHostPort = getConn(listeningHBSvc, false /* outbound */).LocalHostPort
		wantConn = &relay.Conn{
			RemoteAddr:        connHostPort,
			RemoteProcessName: listeningHBSvc.PeerInfo().ProcessName,
			IsOutbound:        true, // outbound connection according to relay.
			Context:           context.Background(),
		}
		assert.Equal(t, wantConn, gotConn, "Unexpected remote addr")
	})
}

// TestRelayConnectionClosed verifies that when the relay host returns a
// protocol error, the caller receives that error and the relay ends up with
// no open connections within one second.
func TestRelayConnectionClosed(t *testing.T) {
	protocolErr := NewSystemError(ErrCodeProtocol, "invalid service name")
	getHost := func(relay.CallFrame, *relay.Conn) (string, error) {
		return "", protocolErr
	}

	opts := testutils.NewOpts().
		SetRelayOnly().
		SetCheckFramePooling().
		SetRelayHost(relaytest.HostFunc(getHost))
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// The client receives a protocol error which causes the following logs.
		opts := testutils.NewOpts().
			AddLogFilter("Peer reported protocol error", 1).
			AddLogFilter("Connection error", 1)
		client := ts.NewClient(opts)

		err := testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil)
		assert.Equal(t, protocolErr, err, "Unexpected error on call")

		// Poll until the relay reports that it has no connections left.
		closedAll := testutils.WaitFor(time.Second, func() bool {
			return ts.Relay().IntrospectNumConnections() == 0
		})
		assert.True(t, closedAll, "Relay should close client connection")
	})
}

// TestRelayUsesRootPeers verifies that relaying a call leaves the relay
// channel's Peers() list empty.
func TestRelayUsesRootPeers(t *testing.T) {
	opts := testutils.NewOpts().
		SetRelayOnly().
SetCheckFramePooling()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		testutils.RegisterEcho(ts.Server(), nil)
		client := testutils.NewClient(t, nil)
		err := testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil)
		assert.NoError(t, err, "Echo failed")
		assert.Len(t, ts.Relay().Peers().Copy(), 0, "Peers should not be modified by relay")
	})
}

// Ensure that if the relay receives a call on a connection that is not active,
// it declines the call, and increments a relay-client-conn-inactive stat.
func TestRelayRejectsDuringClose(t *testing.T) {
	opts := testutils.NewOpts().
		SetRelayOnly().
		SetCheckFramePooling().
		AddLogFilter("Failed to relay frame.", 1, "error", "incoming connection is not active: connectionStartClose")
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// The echo handler signals that the first call arrived and then
		// holds it open until block is closed.
		gotCall := make(chan struct{})
		block := make(chan struct{})

		testutils.RegisterEcho(ts.Server(), func() {
			close(gotCall)
			<-block
		})

		client := ts.NewClient(nil)

		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			defer wg.Done()
			testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())
		}()

		<-gotCall

		// Close the relay so that it stops accepting more calls.
		ts.Relay().Close()
		err := testutils.CallEcho(client, ts.HostPort(), ts.ServiceName(), nil)
		require.Error(t, err, "Expect call to fail after relay is shutdown")
		assert.Contains(t, err.Error(), "incoming connection is not active")

		close(block)
		wg.Wait()

		// We have a successful call that ran in the goroutine
		// and a failed call that we just checked the error on.
		calls := relaytest.NewMockStats()
		calls.Add(client.PeerInfo().ServiceName, ts.ServiceName(), "echo").
			Succeeded().End()
		calls.Add(client.PeerInfo().ServiceName, ts.ServiceName(), "echo").
			// No peer is set since we rejected the call before selecting one.
			Failed("relay-client-conn-inactive").End()
		ts.AssertRelayStats(calls)
	})
}

// TestRelayRateLimitDrop verifies that when the relay host returns a
// relay.RateLimitDropError the call never reaches the server, the caller
// times out, and the relay records a "relay-dropped" failure.
func TestRelayRateLimitDrop(t *testing.T) {
	getHost := func(relay.CallFrame, *relay.Conn) (string, error) {
		return "", relay.RateLimitDropError{}
	}

	opts := testutils.NewOpts().
		SetRelayOnly().
		SetCheckFramePooling().
		SetRelayHost(relaytest.HostFunc(getHost))
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// The handler would run on a different goroutine than the final
		// check below, so track the flag atomically.
		var gotCall atomic.Bool
		testutils.RegisterEcho(ts.Server(), func() {
			gotCall.Store(true)
		})

		client := ts.NewClient(nil)

		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			// Register Done first so wg.Wait below can never hang,
			// regardless of how the rest of this goroutine exits.
			defer wg.Done()

			// We want to use a low timeout here since the test waits for this
			// call to timeout.
			ctx, cancel := NewContext(testutils.Timeout(100 * time.Millisecond))
			defer cancel()

			_, _, _, err := raw.Call(ctx, client, ts.HostPort(), ts.ServiceName(), "echo", nil, nil)
			// Use assert rather than require: FailNow must only be called
			// from the goroutine running the test function.
			assert.Equal(t, ErrTimeout, err, "Expected CallEcho to fail")
		}()

		wg.Wait()
		assert.False(t, gotCall.Load(), "Server should not receive a call")

		calls := relaytest.NewMockStats()
		calls.Add(client.PeerInfo().ServiceName, ts.ServiceName(), "echo").
			Failed("relay-dropped").End()
		ts.AssertRelayStats(calls)
	})
}

// Test that a stalled connection to a single server does not block all calls
// from that server, and we have stats to capture that this is happening.
func TestRelayStalledConnection(t *testing.T) {
	// TODO(ablackmon): Debug why this is flaky in github
	if os.Getenv("GITHUB_WORKFLOW") != "" {
		t.Skip("skipping test flaky in github actions.")
	}

	// TODO: enable framepool checks
	opts := testutils.NewOpts().
		AddLogFilter("Dropping call due to slow connection.", 1, "sendChCapacity", "32").
		SetSendBufferSize(32). // We want to hit the buffer size earlier, but also ensure we're only dropping once the sendCh is full.
		SetServiceName("s1").
		SetRelayOnly()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		s2 := ts.NewServer(testutils.NewOpts().SetServiceName("s2"))
		testutils.RegisterEcho(s2, nil)

		// stallHandler does not read any arguments until stall is closed,
		// and reports via stallComplete once ReadArgs has returned.
		stall := make(chan struct{})
		stallComplete := make(chan struct{})
		stallHandler := func(ctx context.Context, call *InboundCall) {
			<-stall
			raw.ReadArgs(call)
			close(stallComplete)
		}
		ts.Register(HandlerFunc(stallHandler), "echo")

		ctx, cancel := NewContext(testutils.Timeout(300 * time.Millisecond))
		defer cancel()

		client := ts.NewClient(nil)
		call, err := client.BeginCall(ctx, ts.HostPort(), ts.ServiceName(), "echo", nil)
		require.NoError(t, err, "BeginCall failed")
		writer, err := call.Arg2Writer()
		require.NoError(t, err, "Arg2Writer failed")
		// Keep streaming arg2 data in the background from a looping reader.
		go io.Copy(writer, testreader.Looper([]byte("test")))

		// Try to read the response which might get an error.
		readDone := make(chan struct{})
		go func() {
			defer close(readDone)

			_, err := call.Response().Arg2Reader()
			if assert.Error(t, err, "Expected error while reading") {
				assert.Contains(t, err.Error(), "frame was not sent to remote side")
			}
		}()

		// Wait for the reader to error out.
		select {
		case <-time.After(testutils.Timeout(10 * time.Second)):
			t.Fatalf("Test timed out waiting for reader to fail")
		case <-readDone:
		}

		// We should be able to make calls to s2 even if s1 is stalled.
		testutils.AssertEcho(t, client, ts.HostPort(), "s2")

		// Verify the sendCh is full, and the buffers are utilized.
		state := ts.Relay().IntrospectState(&IntrospectionOptions{})
		connState := state.RootPeers[ts.Server().PeerInfo().HostPort].OutboundConnections[0]
		assert.Equal(t, 32, connState.SendChCapacity, "unexpected SendChCapacity")
		assert.NotZero(t, connState.SendChQueued, "unexpected SendChQueued")
		assert.NotZero(t, connState.SendBufferUsage, "unexpected SendBufferUsage")
		assert.NotZero(t, connState.SendBufferSize, "unexpected SendBufferSize")

		// Cancel the call and unblock the stall handler.
		cancel()
		close(stall)

		// The server channel will not close until the stall handler receives
		// an error. Since we don't propagate cancels, the handler will keep
		// trying to read arguments till the timeout.
		select {
		case <-stallComplete:
		case <-time.After(testutils.Timeout(300 * time.Millisecond)):
			t.Fatalf("Stall handler did not complete")
		}

		calls := relaytest.NewMockStats()
		calls.Add(client.PeerInfo().ServiceName, ts.ServiceName(), "echo").
			Failed("relay-dest-conn-slow").End()
		calls.Add(client.PeerInfo().ServiceName, "s2", "echo").
			Succeeded().End()
		ts.AssertRelayStats(calls)
	})
}

// Test that a stalled connection to the client does not cause stuck calls
// See https://github.com/uber/tchannel-go/issues/700 for more info.
func TestRelayStalledClientConnection(t *testing.T) {
	// This needs to be large enough to fill up the client TCP buffer.
	const _calls = 100

	// TODO: enable framepool checks
	opts := testutils.NewOpts().
		// Expect errors from dropped frames.
		AddLogFilter("Dropping call due to slow connection.", _calls).
		SetSendBufferSize(10). // We want to hit the buffer size earlier.
		SetServiceName("s1").
		SetRelayOnly()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		// Track when the server receives calls
		gotCall := make(chan struct{}, _calls)
		testutils.RegisterEcho(ts.Server(), func() {
			gotCall <- struct{}{}
		})

		// Create a frame relay that will block all client inbound frames.
unblockClientInbound := make(chan struct{}) blockerHostPort, relayCancel := testutils.FrameRelay(t, ts.HostPort(), func(outgoing bool, f *Frame) *Frame { if !outgoing && f.Header.ID > 1 { // Block all inbound frames except the initRes <-unblockClientInbound } return f }) defer relayCancel() defer close(unblockClientInbound) client := ts.NewClient(nil) ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() var calls []*OutboundCall // Data to fit one frame fully, but large enough that a number of these frames will fill // all the buffers and cause the relay to drop the response frame. Buffers are: // 1. Relay's sendCh on the connection to the client (set to 10 frames explicitly) // 2. Relay's TCP send buffer for the connection to the client. // 3. Client's TCP receive buffer on the connection to the relay. data := bytes.Repeat([]byte("test"), 256*60) for i := 0; i < _calls; i++ { call, err := client.BeginCall(ctx, blockerHostPort, ts.ServiceName(), "echo", nil) require.NoError(t, err, "BeginCall failed") require.NoError(t, NewArgWriter(call.Arg2Writer()).Write(nil), "arg2 write failed") require.NoError(t, NewArgWriter(call.Arg3Writer()).Write(data), "arg2 write failed") // Wait for server to receive the call <-gotCall calls = append(calls, call) } // Wait for all calls to end on the relay, and ensure we got failures from the slow client. stats := ts.RelayHost().Stats() stats.WaitForEnd() assert.Contains(t, stats.Map(), "testService-client->s1::echo.failed-relay-source-conn-slow", "Expect at least 1 failed call due to slow client") // We don't read the responses, as we want the client's TCP buffers to fill up // and the relay to drop calls. However, we should unblock the client reader // to make sure the client channel can close. // Unblock the client so it can close. 
cancel() for _, call := range calls { require.Error(t, NewArgReader(call.Response().Arg2Reader()).Read(&data), "should fail to read response") } }) } // Test that a corrupted callRes frame results in log emission. We set up the following: // // client <-> relay <-> man-in-the-middle (MITM) relay <-> server // // The MITM relay is configured to intercept and corrupt response frames (through truncation) // sent back from the server, and forward them back to the relay, where it is checked for errors. func TestRelayCorruptedCallResFrame(t *testing.T) { // TODO: Debug why this is flaky in github if os.Getenv("GITHUB_WORKFLOW") != "" { t.Skip("skipping test flaky in github actions.") } opts := testutils.NewOpts(). // Expect errors from corrupted callRes frames. AddLogFilter("Malformed callRes frame.", 1). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { s1 := testutils.NewServer(t, testutils.NewOpts().SetServiceName("s1")) defer s1.Close() // Track when the server receives the call gotCall := make(chan struct{}) testutils.RegisterFunc(s1, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { gotCall <- struct{}{} return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, nil }) mitmHostPort, relayCancel := testutils.FrameRelay(t, s1.PeerInfo().HostPort, func(outgoing bool, f *Frame) *Frame { // We care only about callRes frames if f.Header.MessageType() == 0x04 { // Corrupt the frame by truncating its payload size to 1 byte f.Header.SetPayloadSize(1) } return f }) defer relayCancel() // The relay only forwards requests to the MITM relay ts.RelayHost().Add("s1", mitmHostPort) client := ts.NewClient(nil) defer client.Close() ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() data := bytes.Repeat([]byte("test"), 256*60) call, err := client.BeginCall(ctx, ts.Relay().PeerInfo().HostPort, "s1", "echo", nil) require.NoError(t, err, "BeginCall failed") require.NoError(t, 
NewArgWriter(call.Arg2Writer()).Write(nil), "arg2 write failed") require.NoError(t, NewArgWriter(call.Arg3Writer()).Write(data), "arg2 write failed") // Wait for server to receive the call <-gotCall // Unblock the client so it can close. cancel() require.Error(t, NewArgReader(call.Response().Arg2Reader()).Read(&data), "should fail to read response") }) } func TestRelayThroughSeparateRelay(t *testing.T) { // TODO: enable framepool checks opts := testutils.NewOpts(). SetRelayOnly() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { serverHP := ts.Server().PeerInfo().HostPort dummyFactory := func(relay.CallFrame, *relay.Conn) (string, error) { panic("should not get invoked") } relay2Opts := testutils.NewOpts().SetRelayHost(relaytest.HostFunc(dummyFactory)) relay2 := ts.NewServer(relay2Opts) // Override where the peers come from. ts.RelayHost().SetChannel(relay2) relay2.GetSubChannel(ts.ServiceName(), Isolated).Peers().Add(serverHP) testutils.RegisterEcho(ts.Server(), nil) client := ts.NewClient(nil) testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) numConns := func(p PeerRuntimeState) int { return len(p.InboundConnections) + len(p.OutboundConnections) } // Verify that there are no connections from ts.Relay() to the server. introspected := ts.Relay().IntrospectState(nil) assert.Zero(t, numConns(introspected.RootPeers[serverHP]), "Expected no connections from relay to server") introspected = relay2.IntrospectState(nil) assert.Equal(t, 1, numConns(introspected.RootPeers[serverHP]), "Expected 1 connection from relay2 to server") }) } func TestRelayConcurrentNewConnectionAttempts(t *testing.T) { opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { // Create a server that is slow to accept connections by using // a frame relay to slow down the initial message. 
slowServer := testutils.NewServer(t, serviceNameOpts("slow-server"))
		defer slowServer.Close()
		testutils.RegisterEcho(slowServer, nil)

		// Only delay until the flag has been set; later frames pass straight through.
		var delayed atomic.Bool
		relayFunc := func(outgoing bool, f *Frame) *Frame {
			if !delayed.Load() {
				time.Sleep(testutils.Timeout(50 * time.Millisecond))
				delayed.Store(true)
			}
			return f
		}

		// Named relayCancel (as at the other FrameRelay call sites) so the
		// builtin close is not shadowed inside this closure.
		slowHP, relayCancel := testutils.FrameRelay(t, slowServer.PeerInfo().HostPort, relayFunc)
		defer relayCancel()
		ts.RelayHost().Add("slow-server", slowHP)

		// Make concurrent calls to trigger concurrent getConnectionRelay calls.
		var wg sync.WaitGroup
		for i := 0; i < 5; i++ {
			wg.Add(1)

			// Create client and get dest host:port in the main goroutine to avoid races.
			client := ts.NewClient(nil)
			relayHostPort := ts.HostPort()
			go func() {
				defer wg.Done()
				testutils.AssertEcho(t, client, relayHostPort, "slow-server")
			}()
		}
		wg.Wait()

		// Verify that the slow server only received a single connection.
		inboundConns := 0
		for _, state := range slowServer.IntrospectState(nil).RootPeers {
			inboundConns += len(state.InboundConnections)
		}
		assert.Equal(t, 1, inboundConns, "Expected a single inbound connection to the server")
	})
}

// TestRelayRaceTimerCausesStuckConnectionOnClose issues many calls from
// several clients against a handler that blackholes its response, using a
// timeout equal to a previously measured successful call time. The call
// results are ignored; the test relies on the test server's own checks when
// it shuts down.
func TestRelayRaceTimerCausesStuckConnectionOnClose(t *testing.T) {
	// TODO: enable framepool checks
	// TODO(ablackmon): Debug why this is flaky in github
	if os.Getenv("GITHUB_WORKFLOW") != "" {
		t.Skip("skipping test flaky in github actions.")
	}

	const (
		concurrentClients = 15
		callsPerClient    = 100
	)
	opts := testutils.NewOpts().
		SetRelayOnly().
		SetSendBufferSize(concurrentClients * callsPerClient) // Avoid dropped frames causing unexpected logs.
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		testutils.RegisterEcho(ts.Server(), nil)

		// Create clients and ensure we can make a successful request.
		clients := make([]*Channel, concurrentClients)

		// callTime ends up holding the duration of the last warm-up call and
		// is used as the timeout for the calls below.
		var callTime time.Duration
		for i := range clients {
			clients[i] = ts.NewClient(opts)
			started := time.Now()
			testutils.AssertEcho(t, clients[i], ts.HostPort(), ts.ServiceName())
			callTime = time.Since(started)
		}

		// Overwrite the echo method with one that times out for the test.
		ts.Server().Register(HandlerFunc(func(ctx context.Context, call *InboundCall) {
			call.Response().Blackhole()
		}), "echo")

		var wg sync.WaitGroup
		for i := 0; i < concurrentClients; i++ {
			wg.Add(1)
			go func(client *Channel) {
				defer wg.Done()

				for j := 0; j < callsPerClient; j++ {
					// Make many concurrent calls, some of which should time out.
					ctx, cancel := NewContext(callTime)
					raw.Call(ctx, client, ts.HostPort(), ts.ServiceName(), "echo", nil, nil)
					cancel()
				}
			}(clients[i])
		}

		wg.Wait()
	})
}

// TestRelayRaceCompletionAndTimeout makes many concurrent calls whose timeout
// equals a previously measured call duration, so that timeouts and response
// frames are likely to be processed at about the same time.
func TestRelayRaceCompletionAndTimeout(t *testing.T) {
	// TODO: enable framepool checks
	const numCalls = 100

	opts := testutils.NewOpts().
		AddLogFilter("simpleHandler OnError.", numCalls).
		// Trigger deletion on timeout, see https://github.com/uber/tchannel-go/issues/808.
		SetRelayMaxTombs(numCalls/2).
		// Hitting max tombs will cause the following logs:
		AddLogFilter("Too many tombstones, deleting relay item immediately.", numCalls).
		AddLogFilter("Received a frame without a RelayItem.", numCalls).
		AddLogFilter("Attempted to create new mex after mexset shutdown.", numCalls).
		SetRelayOnly()
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		testutils.RegisterEcho(ts.Server(), nil)

		client := ts.NewClient(nil)
		started := time.Now()
		testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName())
		callTime := time.Since(started)

		// Make many calls with the same timeout, with the goal of
		// timing out right as we process the response frame.
var wg sync.WaitGroup for i := 0; i < numCalls; i++ { wg.Add(1) go func() { defer wg.Done() ctx, cancel := NewContext(callTime) raw.Call(ctx, client, ts.HostPort(), ts.ServiceName(), "echo", nil, nil) cancel() }() } // Some of those calls should triger the race. wg.Wait() }) } func TestRelayArg2OffsetIntegration(t *testing.T) { ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() rh := relaytest.NewStubRelayHost() frameCh := inspectFrames(rh) opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). SetRelayHost(rh) testutils.WithTestServer(t, opts, func(tb testing.TB, ts *testutils.TestServer) { const ( testMethod = "echo" arg2Data = "arg2-is" arg3Data = "arg3-here" ) var ( wantArg2Start = len(ts.ServiceName()) + len(testMethod) + 70 /*data before arg1*/ payloadLeft = MaxFramePayloadSize - wantArg2Start ) testutils.RegisterEcho(ts.Server(), nil) client := testutils.NewClient(t, nil /*opts*/) defer client.Close() tests := []struct { msg string arg2Data string arg2Flush bool arg2PostFlushData string noArg3 bool wantEndOffset int wantHasMore bool }{ { msg: "all within a frame", arg2Data: arg2Data, wantEndOffset: wantArg2Start + len(arg2Data), wantHasMore: false, }, { msg: "arg2 flushed", arg2Data: arg2Data, arg2Flush: true, wantEndOffset: wantArg2Start + len(arg2Data), wantHasMore: true, }, { msg: "arg2 flushed called then write again", arg2Data: arg2Data, arg2Flush: true, arg2PostFlushData: "more data", wantEndOffset: wantArg2Start + len(arg2Data), wantHasMore: true, }, { msg: "no arg2 but flushed", wantEndOffset: wantArg2Start, wantHasMore: false, }, { msg: "XL arg2 which is fragmented", arg2Data: string(make([]byte, MaxFrameSize+100)), wantEndOffset: wantArg2Start + payloadLeft, wantHasMore: true, }, { msg: "large arg2 with 3 bytes left for arg3", arg2Data: string(make([]byte, payloadLeft-3)), wantEndOffset: wantArg2Start + payloadLeft - 3, wantHasMore: false, }, { msg: "large arg2, 2 bytes left", arg2Data: string(make([]byte, 
payloadLeft-2)), wantEndOffset: wantArg2Start + payloadLeft - 2, wantHasMore: true, // no arg3 }, { msg: "large arg2, 2 bytes left, no arg3", arg2Data: string(make([]byte, payloadLeft-2)), wantEndOffset: wantArg2Start + payloadLeft - 2, noArg3: true, wantHasMore: true, // no arg3 and still got CALL_REQ_CONTINUE }, { msg: "large arg2, 1 bytes left", arg2Data: string(make([]byte, payloadLeft-1)), wantEndOffset: wantArg2Start + payloadLeft - 1, wantHasMore: true, // no arg3 }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { call, err := client.BeginCall(ctx, ts.HostPort(), ts.ServiceName(), testMethod, nil) require.NoError(t, err, "BeginCall failed") writer, err := call.Arg2Writer() require.NoError(t, err) _, err = writer.Write([]byte(tt.arg2Data)) require.NoError(t, err) if tt.arg2Flush { writer.Flush() // tries to write after flush if tt.arg2PostFlushData != "" { _, err := writer.Write([]byte(tt.arg2PostFlushData)) require.NoError(t, err) } } require.NoError(t, writer.Close()) arg3DataToWrite := arg3Data if tt.noArg3 { arg3DataToWrite = "" } require.NoError(t, NewArgWriter(call.Arg3Writer()).Write([]byte(arg3DataToWrite)), "arg3 write failed") f := <-frameCh start := f.Arg2StartOffset() end, hasMore := f.Arg2EndOffset() assert.Equal(t, wantArg2Start, start, "arg2 start offset does not match expectation") assert.Equal(t, tt.wantEndOffset, end, "arg2 end offset does not match expectation") assert.Equal(t, tt.wantHasMore, hasMore, "arg2 hasMore bit does not match expectation") gotArg2, gotArg3, err := raw.ReadArgsV2(call.Response()) assert.NoError(t, err) assert.Equal(t, tt.arg2Data+tt.arg2PostFlushData, string(gotArg2), "arg2 in response does not meet expectation") assert.Equal(t, arg3DataToWrite, string(gotArg3), "arg3 in response does not meet expectation") }) } }) } func TestRelayThriftArg2KeyValueIteration(t *testing.T) { ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() rh := relaytest.NewStubRelayHost() frameCh := 
inspectFrames(rh) opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). SetRelayHost(rh) testutils.WithTestServer(t, opts, func(tb testing.TB, ts *testutils.TestServer) { kv := map[string]string{ "key": "val", "key2": "valval", "longkey": "valvalvalval", } arg2Buf := thriftarg2test.BuildKVBuffer(kv) const ( testMethod = "echo" arg3Data = "arg3-here" ) testutils.RegisterEcho(ts.Server(), nil) client := testutils.NewClient(t, nil /*opts*/) defer client.Close() call, err := client.BeginCall(ctx, ts.HostPort(), ts.ServiceName(), testMethod, &CallOptions{Format: Thrift}) require.NoError(t, err, "BeginCall failed") require.NoError(t, NewArgWriter(call.Arg2Writer()).Write(arg2Buf), "arg2 write failed") require.NoError(t, NewArgWriter(call.Arg3Writer()).Write([]byte(arg3Data)), "arg3 write failed") f := <-frameCh iter, err := f.Arg2Iterator() gotKV := make(map[string]string) for err == nil { gotKV[string(iter.Key())] = string(iter.Value()) iter, err = iter.Next() } assert.Equal(t, kv, gotKV) assert.Equal(t, io.EOF, err) gotArg2, gotArg3, err := raw.ReadArgsV2(call.Response()) assert.NoError(t, err) assert.Equal(t, string(arg2Buf), string(gotArg2), "arg2 in response does not meet expectation") assert.Equal(t, arg3Data, string(gotArg3), "arg3 in response does not meet expectation") }) } func TestRelayConnectionTimeout(t *testing.T) { var ( minTimeout = testutils.Timeout(100 * time.Millisecond) maxTimeout = testutils.Timeout(time.Minute) ) tests := []struct { msg string callTimeout time.Duration maxConnTimeout time.Duration minTime time.Duration }{ { msg: "only call timeout is set", callTimeout: 2 * minTimeout, }, { msg: "call timeout < relay timeout", callTimeout: 2 * minTimeout, maxConnTimeout: 2 * maxTimeout, }, { msg: "relay timeout < call timeout", callTimeout: 2 * maxTimeout, maxConnTimeout: 2 * minTimeout, }, { msg: "relay timeout == call timeout", callTimeout: 2 * minTimeout, maxConnTimeout: 2 * minTimeout, }, } for _, tt := range tests { t.Run(tt.msg, 
func(t *testing.T) { opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). SetRelayMaxConnectionTimeout(tt.maxConnTimeout). AddLogFilter("Failed during connection handshake.", 1). AddLogFilter("Failed to connect to relay host.", 1) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err, "Failed to listen") defer ln.Close() // TCP listener will never complete the handshake and always timeout. ts.RelayHost().Add("blocked", ln.Addr().String()) start := time.Now() ctx, cancel := NewContext(testutils.Timeout(tt.callTimeout)) defer cancel() // We expect connection error logs from the client. client := ts.NewClient(nil /* opts */) _, _, _, err = raw.Call(ctx, client, ts.HostPort(), "blocked", "echo", nil, nil) assert.Equal(t, ErrTimeout, err) taken := time.Since(start) if taken < minTimeout || taken > maxTimeout { t.Errorf("Took %v, expected [%v, %v]", taken, minTimeout, maxTimeout) } }) }) } } func TestRelayTransferredBytes(t *testing.T) { const ( kb = 1024 // The maximum delta between the payload size and the bytes on wire. protocolBuffer = kb ) rh := relaytest.NewStubRelayHost() opts := testutils.NewOpts(). SetRelayHost(rh). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(tb testing.TB, ts *testutils.TestServer) { // Note: Upcast to testing.T so we can use t.Run. t := tb.(*testing.T) s1 := ts.NewServer(testutils.NewOpts().SetServiceName("s1")) s2 := ts.NewServer(testutils.NewOpts().SetServiceName("s2")) testutils.RegisterEcho(s1, nil) testutils.RegisterEcho(s2, nil) // Add a handler that always returns an empty payload. testutils.RegisterFunc(s2, "swallow", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { fmt.Println("swallow got", len(args.Arg2)+len(args.Arg3)) return &raw.Res{}, nil }) // Helper to make calls with specific payload sizes. 
makeCall := func(src, dst *Channel, method string, arg2Size, arg3Size int) { ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() arg2 := testutils.RandBytes(arg2Size) arg3 := testutils.RandBytes(arg3Size) _, _, _, err := raw.Call(ctx, src, ts.HostPort(), dst.ServiceName(), method, arg2, arg3) require.NoError(t, err) } t.Run("verify sent vs received", func(t *testing.T) { makeCall(s1, s2, "swallow", 4*1024, 4*1024) statsMap := rh.Stats().Map() assert.InDelta(t, 8*kb, statsMap["s1->s2::swallow.sent-bytes"], protocolBuffer, "Unexpected sent bytes") assert.InDelta(t, 0, statsMap["s1->s2::swallow.received-bytes"], protocolBuffer, "Unexpected sent bytes") }) t.Run("verify sent and received", func(t *testing.T) { makeCall(s1, s2, "echo", 4*kb, 4*kb) statsMap := rh.Stats().Map() assert.InDelta(t, 8*kb, statsMap["s1->s2::echo.sent-bytes"], protocolBuffer, "Unexpected sent bytes") assert.InDelta(t, 8*kb, statsMap["s1->s2::echo.received-bytes"], protocolBuffer, "Unexpected sent bytes") }) t.Run("verify large payload", func(t *testing.T) { makeCall(s1, s2, "echo", 128*1024, 128*1024) statsMap := rh.Stats().Map() assert.InDelta(t, 256*kb, statsMap["s1->s2::echo.sent-bytes"], protocolBuffer, "Unexpected sent bytes") assert.InDelta(t, 256*kb, statsMap["s1->s2::echo.received-bytes"], protocolBuffer, "Unexpected sent bytes") }) t.Run("verify reverse call", func(t *testing.T) { makeCall(s2, s1, "echo", 0, 64*kb) statsMap := rh.Stats().Map() assert.InDelta(t, 64*kb, statsMap["s2->s1::echo.sent-bytes"], protocolBuffer, "Unexpected sent bytes") assert.InDelta(t, 64*kb, statsMap["s2->s1::echo.received-bytes"], protocolBuffer, "Unexpected sent bytes") }) }) } func TestRelayCallResponse(t *testing.T) { ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() kv := map[string]string{ "foo": "bar", "baz": "qux", } arg2Buf := thriftarg2test.BuildKVBuffer(kv) rh := relaytest.NewStubRelayHost() rh.SetRespFrameFn(func(frame relay.RespFrame) { require.True(t, 
frame.OK(), "Got unexpected response status") require.Equal(t, Thrift.String(), string(frame.ArgScheme()), "Got unexpected scheme") iter, err := arg2.NewKeyValIterator(frame.Arg2()) require.NoError(t, err, "Got unexpected iterator error") gotKV := make(map[string]string) for ; err == nil; iter, err = iter.Next() { gotKV[string(iter.Key())] = string(iter.Value()) } assert.Equal(t, kv, gotKV, "Got unexpected arg2 in response") }) opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). SetRelayHost(rh) testutils.WithTestServer(t, opts, func(tb testing.TB, ts *testutils.TestServer) { const ( testMethod = "echo" arg3Data = "arg3-here" ) testutils.RegisterEcho(ts.Server(), nil) client := testutils.NewClient(t, nil /*opts*/) defer client.Close() call, err := client.BeginCall(ctx, ts.HostPort(), ts.ServiceName(), testMethod, &CallOptions{Format: Thrift}) require.NoError(t, err, "BeginCall failed") require.NoError(t, NewArgWriter(call.Arg2Writer()).Write(arg2Buf), "arg2 write failed") require.NoError(t, NewArgWriter(call.Arg3Writer()).Write([]byte(arg3Data)), "arg3 write failed") gotArg2, gotArg3, err := raw.ReadArgsV2(call.Response()) assert.NoError(t, err) assert.Equal(t, string(arg2Buf), string(gotArg2), "arg2 in response does not meet expectation") assert.Equal(t, arg3Data, string(gotArg3), "arg3 in response does not meet expectation") }) } func TestRelayAppendArg2SentBytes(t *testing.T) { tests := []struct { msg string appends map[string]string arg3 []byte wantSentBytes int }{ { msg: "without appends", arg3: []byte("hello, world"), wantSentBytes: 130, }, { msg: "with appends", arg3: []byte("hello, world"), appends: map[string]string{"baz": "qux"}, wantSentBytes: 140, // 130 + 2 bytes size + 3 bytes key + 2 byts size + 3 bytes val = 137 }, { msg: "with large appends that result in fragments", arg3: []byte("hello, world"), appends: map[string]string{ "fee": testutils.RandString(16 * 1024), "fii": testutils.RandString(16 * 1024), "foo": 
testutils.RandString(16 * 1024), "fum": testutils.RandString(16 * 1024), }, // original data size = 130 // appended arg2 size = 2 bytes number of keys + 4 * (2 bytes key size + 3 bytes key + 2 bytes val size + 16 * 1024 bytes val) // additional frame preamble = 16 bytes header + 1 byte flag + 1 byte checksum type + 4 bytes checksum size + 2 bytes size of remaining arg2 wantSentBytes: 130 + (2+3+2+16*1024)*4 + 16 + 1 + 1 + 4 + 2, }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { rh := relaytest.NewStubRelayHost() rh.SetFrameFn(func(f relay.CallFrame, conn *relay.Conn) { for k, v := range tt.appends { f.Arg2Append([]byte(k), []byte(v)) } }) opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). SetRelayHost(rh) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { rly := ts.Relay() svr := ts.Server() testutils.RegisterEcho(svr, nil) client := testutils.NewClient(t, nil) ctx, cancel := NewContextBuilder(testutils.Timeout(time.Second)). 
SetFormat(Thrift).Build() defer cancel() sendArgs := &raw.Args{ Arg2: thriftarg2test.BuildKVBuffer(map[string]string{"foo": "bar"}), Arg3: tt.arg3, } recvArg2, recvArg3, _, err := raw.Call(ctx, client, rly.PeerInfo().HostPort, ts.ServiceName(), "echo", sendArgs.Arg2, sendArgs.Arg3) require.NoError(t, err, "Call from %v (%v) to %v (%v) failed", client.ServiceName(), client.PeerInfo().HostPort, ts.ServiceName(), rly.PeerInfo().HostPort) wantArg2 := map[string]string{ "foo": "bar", } for k, v := range tt.appends { wantArg2[k] = v } assert.Equal(t, wantArg2, thriftarg2test.MustReadKVBuffer(t, recvArg2), "Arg2 mismatch") assert.Equal(t, recvArg3, []byte("hello, world"), "Arg3 mismatch") sentBytes := rh.Stats().Map()["testService-client->testService::echo.sent-bytes"] assert.Equal(t, tt.wantSentBytes, sentBytes) }) }) } } func inspectFrames(rh *relaytest.StubRelayHost) chan relay.CallFrame { frameCh := make(chan relay.CallFrame, 1) rh.SetFrameFn(func(f relay.CallFrame, _ *relay.Conn) { frameCh <- testutils.CopyCallFrame(f) }) return frameCh } type relayModifier interface { frameFn(cf relay.CallFrame, _ *relay.Conn) modifyArg2(m map[string]string) map[string]string } type noopRelayModifer struct{} func (nrm *noopRelayModifer) frameFn(_ relay.CallFrame, _ *relay.Conn) {} func (nrm *noopRelayModifer) modifyArg2(m map[string]string) map[string]string { return m } type keyVal struct { key, val string } type arg2KeyValRelayModifier struct { keyValPairs []keyVal } func addFixedKeyVal(kvPairs []keyVal) *arg2KeyValRelayModifier { return &arg2KeyValRelayModifier{ keyValPairs: kvPairs, } } func fillFrameWithArg2(t *testing.T, checksumType ChecksumType, arg1 string, arg2 map[string]string, bytePosFromBoundary int) *arg2KeyValRelayModifier { arg2Key := "foo" arg2Len := 2 // nh for k, v := range arg2 { arg2Len += 2 + len(k) + 2 + len(v) } // Writing an arg adds nh+nk+len(key)+nv+len(val) bytes. calculate the size of val // so that we end at bytePosFromBoundary in the frame. 
remainingSpaceBeforeChecksum // is the number of bytes from the start of the frame up until the checkumType byte, // just before the checksum itself. const remainingSpaceBeforeChecksum = 65441 valSize := remainingSpaceBeforeChecksum + bytePosFromBoundary - (checksumType.ChecksumSize() + 2 /* nArg1 */ + len(arg1) + arg2Len + 2 /* nk */ + len(arg2Key) + 2 /* nv */) if valSize < 0 { t.Fatalf("can't fill arg2 with key %q and %d bytes remaining", arg2Key, bytePosFromBoundary) } return &arg2KeyValRelayModifier{ keyValPairs: []keyVal{ {key: arg2Key, val: testutils.RandString(valSize)}, }, } } func (rm *arg2KeyValRelayModifier) frameFn(cf relay.CallFrame, _ *relay.Conn) { for _, kv := range rm.keyValPairs { cf.Arg2Append([]byte(kv.key), []byte(kv.val)) } } func (rm *arg2KeyValRelayModifier) modifyArg2(m map[string]string) map[string]string { if m == nil { m = make(map[string]string) } for _, kv := range rm.keyValPairs { m[kv.key] = kv.val } return m } func TestRelayModifyArg2(t *testing.T) { const kb = 1024 checksumTypes := []struct { msg string checksumType ChecksumType }{ {"none", ChecksumTypeNone}, {"crc32", ChecksumTypeCrc32}, {"farmhash", ChecksumTypeFarmhash}, {"crc32c", ChecksumTypeCrc32C}, } modifyTests := []struct { msg string skip string modifier func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier }{ { msg: "no change", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return &noopRelayModifer{} }, }, { msg: "add zero-length key/value", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return addFixedKeyVal([]keyVal{{key: "", val: ""}}) }, }, { msg: "add multiple zero-length key/value", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return addFixedKeyVal([]keyVal{ {"", ""}, {"", ""}, {"", ""}, }) }, }, { msg: "add small key/value", modifier: func(t *testing.T, cst ChecksumType, 
arg1 string, arg2 map[string]string) relayModifier { return addFixedKeyVal([]keyVal{ {"foo", "bar"}, {"baz", "qux"}, }) }, }, { msg: "fill the first frame until 2 bytes remain", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return fillFrameWithArg2(t, cst, arg1, arg2, -2) }, }, { msg: "fill the first frame until 1 byte remain", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return fillFrameWithArg2(t, cst, arg1, arg2, -1) }, }, { msg: "fill the first frame to its boundary", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return fillFrameWithArg2(t, cst, arg1, arg2, 0) }, }, { msg: "fill the first frame to 1 byte over its boundary", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return fillFrameWithArg2(t, cst, arg1, arg2, 1) }, }, { msg: "fill the first frame to 2 bytes over its boundary", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return fillFrameWithArg2(t, cst, arg1, arg2, 2) }, }, { msg: "add large key/value which pushes arg2 into 2nd frame", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return addFixedKeyVal([]keyVal{ {"fee", testutils.RandString(65535)}, }) }, }, { msg: "add large key/value which pushes arg2 into 2nd and 3rd frame", modifier: func(t *testing.T, cst ChecksumType, arg1 string, arg2 map[string]string) relayModifier { return addFixedKeyVal([]keyVal{ {"fee", testutils.RandString(65535)}, {"fi", testutils.RandString(65535)}, }) }, }, } // TODO(cinchurge): we need to cover a combination of the following for the payloads: // - no arg2, small arg2, large arg2 (3 or 4 cases that are close/on the boundary) // - no arg3, small arg3, 16kb arg3, 32kb arg3, 64kb arg3, 128kb arg3, 1mb arg3 // - 2 bytes, 1 byte, and 0 bytes from the frame 
boundary for both arg2 and arg3 payloadTests := []struct { msg string arg2 map[string]string arg3 []byte }{ { msg: "no payload", arg2: nil, // empty map arg3: []byte{}, }, { // TODO(cinchurge): ideally we'd like to do tests where arg2 is close to and on the // frame boundary, however since the corresponding arg2 size depends on the sizes of arg1 // and the checksum, we're deferring this to a separate change. msg: "no payload + large arg2", arg2: map[string]string{ "foo": testutils.RandString(60000), }, // empty map arg3: []byte{}, }, { msg: "1kB payloads", arg2: map[string]string{ "existingKey": "existingValue", }, arg3: testutils.RandBytes(kb), }, { msg: "16kB payloads", arg2: map[string]string{ "existingKey": "existingValue", }, arg3: testutils.RandBytes(16 * kb), }, { msg: "32kB payloads", arg2: map[string]string{ "existingKey": "existingValue", }, arg3: testutils.RandBytes(32 * kb), }, { msg: "64kB payloads", arg2: map[string]string{ "existingKey": "existingValue", }, arg3: testutils.RandBytes(64 * kb), }, { msg: "128kB payloads", arg2: map[string]string{ "existingKey": "existingValue", }, arg3: testutils.RandBytes(128 * kb), }, { msg: "1MB payloads", arg2: map[string]string{ "existingKey": "existingValue", }, arg3: testutils.RandBytes(1024 * kb), }, } const ( format = Thrift noErrMethod = "EchoVerifyNoErr" errMethod = "EchoVerifyErr" ) appErrTests := []struct { msg string method string wantAppErr bool }{ { msg: "no app error bit", method: noErrMethod, wantAppErr: false, }, { msg: "app error bit", method: errMethod, wantAppErr: true, }, } for _, mt := range modifyTests { for _, csTest := range checksumTypes { // Make calls with different payloads and expected errors. for _, aet := range appErrTests { for _, tt := range payloadTests { t.Run(fmt.Sprintf("%s,checksum=%s,%s,%s", mt.msg, csTest.msg, aet.msg, tt.msg), func(t *testing.T) { modifier := mt.modifier(t, csTest.checksumType, aet.method, tt.arg2) // Create a relay that will modify the frame as per the test. 
relayHost := relaytest.NewStubRelayHost() relayHost.SetFrameFn(modifier.frameFn) opts := testutils.NewOpts(). SetRelayHost(relayHost). SetRelayOnly(). SetCheckFramePooling() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { // Create a client that uses a specific checksumType. clientOpts := testutils.NewOpts().SetChecksumType(csTest.checksumType) client := ts.NewClient(clientOpts) defer client.Close() // Create a server echo verify endpoints (optionally returning an error). for _, appErrTest := range appErrTests { handler := echoVerifyHandler{ t: t, verifyFormat: format, verifyCaller: client.ServiceName(), verifyMethod: appErrTest.method, appErr: appErrTest.wantAppErr, } ts.Server().Register(raw.Wrap(handler), appErrTest.method) } ctx, cancel := NewContextBuilder(testutils.Timeout(time.Second)). SetFormat(format).Build() defer cancel() arg2Encoded := encodeThriftHeaders(t, tt.arg2) resArg2, resArg3, resp, err := raw.Call(ctx, client, ts.HostPort(), ts.ServiceName(), aet.method, arg2Encoded, tt.arg3) require.NoError(t, err, "%v: Received unexpected error", tt.msg) assert.Equal(t, format, resp.Format(), "%v: Unexpected error format") assert.Equal(t, aet.wantAppErr, resp.ApplicationError(), "%v: Unexpected app error") wantArg2 := modifier.modifyArg2(copyHeaders(tt.arg2)) gotArg2Map := decodeThriftHeaders(t, resArg2) assert.Equal(t, wantArg2, gotArg2Map, "%v: Unexpected arg2 headers", tt.msg) assert.Equal(t, resArg3, tt.arg3, "%v: Unexpected arg3", tt.msg) }) }) } } } } } func TestRelayModifyArg2ShouldFail(t *testing.T) { tests := []struct { msg string arg2 []byte format Format wantErr string }{ { msg: "large arg2, fragmented", arg2: thriftarg2test.BuildKVBuffer(map[string]string{ "fee": testutils.RandString(16 * 1024), "fi": testutils.RandString(16 * 1024), "fo": testutils.RandString(16 * 1024), "fum": testutils.RandString(16 * 1024), }), wantErr: "relay-arg2-modify-failed: fragmented arg2", }, { msg: "non-Thrift call", format: JSON, 
arg2: thriftarg2test.BuildKVBuffer(map[string]string{ "fee": testutils.RandString(16 * 1024), }), wantErr: "relay-arg2-modify-failed: cannot inspect or modify arg2 for non-Thrift calls", }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { rh := relaytest.NewStubRelayHost() rh.SetFrameFn(func(f relay.CallFrame, conn *relay.Conn) { f.Arg2Append([]byte("foo"), []byte("bar")) }) opts := testutils.NewOpts(). SetRelayOnly(). SetCheckFramePooling(). SetRelayHost(rh). AddLogFilter("Failed to send call with modified arg2.", 1) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { rly := ts.Relay() callee := ts.Server() testutils.RegisterEcho(callee, nil) caller := ts.NewServer(testutils.NewOpts()) testutils.RegisterEcho(caller, nil) baseCtx := context.WithValue(context.Background(), "foo", "bar") ctx, cancel := NewContextBuilder(time.Second).SetConnectBaseContext(baseCtx).Build() defer cancel() require.NoError(t, rly.Ping(ctx, caller.PeerInfo().HostPort)) err := testutils.CallEcho(caller, ts.HostPort(), ts.ServiceName(), &raw.Args{ Format: tt.format, Arg2: tt.arg2, }) require.Error(t, err, "should fail to send call with large arg2") assert.Contains(t, err.Error(), tt.wantErr, "unexpected error") // Even after a failure, a simple call should still suceed (e.g., connection is left in a safe state). err = testutils.CallEcho(caller, ts.HostPort(), ts.ServiceName(), &raw.Args{ Format: Thrift, Arg2: encodeThriftHeaders(t, map[string]string{"key": "value"}), Arg3: testutils.RandBytes(100), }) require.NoError(t, err, "Standard Thrift call should not fail") }) }) } } // echoVerifyHandler is an echo handler with some added verification of // the call metadata (e.g., caller, format). 
type echoVerifyHandler struct { t testing.TB appErr bool verifyFormat Format verifyCaller string verifyMethod string } func (h echoVerifyHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { assert.Equal(h.t, h.verifyFormat, args.Format, "Unexpected format") assert.Equal(h.t, h.verifyCaller, args.Caller, "Unexpected caller") assert.Equal(h.t, h.verifyMethod, args.Method, "Unexpected method") return &raw.Res{ Arg2: args.Arg2, Arg3: args.Arg3, IsErr: h.appErr, }, nil } func (h echoVerifyHandler) OnError(ctx context.Context, err error) { h.t.Errorf("unexpected OnError: %v", err) } func encodeThriftHeaders(t testing.TB, m map[string]string) []byte { var buf bytes.Buffer require.NoError(t, thrift.WriteHeaders(&buf, m), "Failed to write headers") return buf.Bytes() } func decodeThriftHeaders(t testing.TB, bs []byte) map[string]string { r := bytes.NewReader(bs) m, err := thrift.ReadHeaders(r) require.NoError(t, err, "Failed to read headers") // Ensure there are no remaining bytes left. remaining, err := ioutil.ReadAll(r) require.NoError(t, err, "failed to read from arg2 reader") assert.Empty(t, remaining, "expected no bytes after reading headers") return m } func copyHeaders(m map[string]string) map[string]string { if m == nil { return nil } copied := make(map[string]string, len(m)) for k, v := range m { copied[k] = v } return copied } ================================================ FILE: relay_timer_pool.go ================================================ // Copyright (c) 2017 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "math" "sync" "time" ) type relayTimerTrigger func(items *relayItems, id uint32, isOriginator bool) type relayTimerPool struct { pool sync.Pool trigger relayTimerTrigger verify bool } type relayTimer struct { pool *relayTimerPool // const timer *time.Timer // const active bool // mutated on Start/Stop stopped bool // mutated on Stop released bool // mutated on Get/Release. // Per-timer parameters passed back when the timer is triggered. items *relayItems id uint32 isOriginator bool } func (rt *relayTimer) OnTimer() { rt.verifyNotReleased() items, id, isOriginator := rt.items, rt.id, rt.isOriginator rt.markTimerInactive() rt.pool.trigger(items, id, isOriginator) } func newRelayTimerPool(trigger relayTimerTrigger, verify bool) *relayTimerPool { return &relayTimerPool{ trigger: trigger, verify: verify, } } // Get returns a relay timer that has not started. 
Timers must be started explicitly // using the Start function. func (tp *relayTimerPool) Get() *relayTimer { timer, ok := tp.pool.Get().(*relayTimer) if ok { timer.released = false return timer } rt := &relayTimer{ pool: tp, } // Go timers are started by default. However, we need to separate creating // the timer and starting the timer for use in the relay code paths. // To make this work without more locks in the relayTimer, we create a Go timer // with a huge timeout so it doesn't run, then stop it so we can start it later. rt.timer = time.AfterFunc(time.Duration(math.MaxInt64), rt.OnTimer) if !rt.timer.Stop() { panic("relayTimer requires timers in stopped state, but failed to stop underlying timer") } return rt } // Put returns a relayTimer back to the pool. func (tp *relayTimerPool) Put(rt *relayTimer) { if tp.verify { // If we are trying to verify correct pool behavior, then we don't release // the timer, and instead ensure no methods are called after being released. return } tp.pool.Put(rt) } // Start starts a timer with the given duration for the specified ID. func (rt *relayTimer) Start(d time.Duration, items *relayItems, id uint32, isOriginator bool) { rt.verifyNotReleased() if rt.active { panic("Tried to start an already-active timer") } rt.active = true rt.stopped = false rt.items = items rt.id = id rt.isOriginator = isOriginator if wasActive := rt.timer.Reset(d); wasActive { panic("relayTimer's underlying timer was Started multiple times without Stop") } } func (rt *relayTimer) markTimerInactive() { rt.active = false rt.items = nil rt.id = 0 rt.items = nil rt.isOriginator = false } // Stop stops the timer and returns whether the timer was stopped. // If the timer has been executed, it returns false, but in all other // cases, it returns true (even if the timer was stopped previously). 
func (rt *relayTimer) Stop() bool { rt.verifyNotReleased() if rt.stopped { return true } stopped := rt.timer.Stop() if stopped { rt.stopped = true rt.markTimerInactive() } return stopped } // Release releases a timer back to the timer pool. The timer MUST have run or be // stopped before Release is called. func (rt *relayTimer) Release() { rt.verifyNotReleased() if rt.active { panic("only stopped or completed timers can be released") } rt.released = true rt.pool.Put(rt) } func (rt *relayTimer) verifyNotReleased() { if rt.released { panic("Released timer cannot be used") } } ================================================ FILE: reqres.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel

import (
	"fmt"

	"github.com/uber/tchannel-go/typed"
)

// errReqResWriterStateMismatch describes a write attempted while the writer
// was in a state other than the one required for that write.
type errReqResWriterStateMismatch struct {
	state, expectedState reqResWriterState
}

// Error reports the writer's actual state and the state that was expected.
func (e errReqResWriterStateMismatch) Error() string {
	const format = "attempting write outside of expected state, in %v expected %v"
	return fmt.Sprintf(format, e.state, e.expectedState)
}

// errReqResReaderStateMismatch describes a read attempted while the reader
// was in a state other than the one required for that read.
type errReqResReaderStateMismatch struct {
	state, expectedState reqResReaderState
}

// Error reports the reader's actual state and the state that was expected.
func (e errReqResReaderStateMismatch) Error() string {
	const format = "attempting read outside of expected state, in %v expected %v"
	return fmt.Sprintf(format, e.state, e.expectedState)
}

// reqResWriterState identifies how far a request/response writer has
// progressed through its arguments.
type reqResWriterState int

// The writer states, in the order a writer moves through them: before each
// of the three arguments is written, then complete.
const (
	reqResWriterPreArg1 reqResWriterState = iota
	reqResWriterPreArg2
	reqResWriterPreArg3
	reqResWriterComplete
)

//go:generate stringer -type=reqResWriterState

// messageForFragment returns the message to use for a fragment; initial
// distinguishes the first fragment from follow-on fragments.
type messageForFragment func(initial bool) message

// A reqResWriter writes out requests/responses. Exactly which it does is
// determined by its messageForFragment function which returns the appropriate
// message to use when building an initial or follow-on fragment.
type reqResWriter struct { conn *Connection contents *fragmentingWriter mex *messageExchange state reqResWriterState messageForFragment messageForFragment log Logger err error } //go:generate stringer -type=reqResReaderState func (w *reqResWriter) argWriter(last bool, inState reqResWriterState, outState reqResWriterState) (ArgWriter, error) { if w.err != nil { return nil, w.err } if w.state != inState { return nil, w.failed(errReqResWriterStateMismatch{state: w.state, expectedState: inState}) } argWriter, err := w.contents.ArgWriter(last) if err != nil { return nil, w.failed(err) } w.state = outState return argWriter, nil } func (w *reqResWriter) arg1Writer() (ArgWriter, error) { return w.argWriter(false /* last */, reqResWriterPreArg1, reqResWriterPreArg2) } func (w *reqResWriter) arg2Writer() (ArgWriter, error) { return w.argWriter(false /* last */, reqResWriterPreArg2, reqResWriterPreArg3) } func (w *reqResWriter) arg3Writer() (ArgWriter, error) { return w.argWriter(true /* last */, reqResWriterPreArg3, reqResWriterComplete) } // newFragment creates a new fragment for marshaling into func (w *reqResWriter) newFragment(initial bool, checksum Checksum) (*writableFragment, error) { if err := w.mex.checkError(); err != nil { return nil, w.failed(err) } message := w.messageForFragment(initial) // Create the frame frame := w.conn.opts.FramePool.Get() frame.Header.ID = w.mex.msgID frame.Header.messageType = message.messageType() // Write the message into the fragment, reserving flags and checksum bytes wbuf := typed.NewWriteBuffer(frame.Payload[:]) fragment := new(writableFragment) fragment.frame = frame fragment.flagsRef = wbuf.DeferByte() if err := message.write(wbuf); err != nil { return nil, err } wbuf.WriteSingleByte(byte(checksum.TypeCode())) fragment.checksumRef = wbuf.DeferBytes(checksum.Size()) fragment.checksum = checksum fragment.contents = wbuf return fragment, wbuf.Err() } // flushFragment sends a fragment to the peer over the connection func (w 
*reqResWriter) flushFragment(fragment *writableFragment) error { if w.err != nil { return w.err } frame := fragment.frame frame.Header.SetPayloadSize(uint16(fragment.contents.BytesWritten())) if err := w.mex.checkError(); err != nil { return w.failed(err) } select { case <-w.mex.ctx.Done(): return w.failed(GetContextError(w.mex.ctx.Err())) case <-w.mex.errCh.c: return w.failed(w.mex.errCh.err) case w.conn.sendCh <- frame: return nil } } // failed marks the writer as having failed func (w *reqResWriter) failed(err error) error { w.log.Debugf("writer failed: %v existing err: %v", err, w.err) if w.err != nil { return w.err } w.mex.shutdown() w.err = err return w.err } // reqResReaderState defines the state of a request/response reader type reqResReaderState int const ( reqResReaderPreArg1 reqResReaderState = iota reqResReaderPreArg2 reqResReaderPreArg3 reqResReaderComplete ) // A reqResReader is capable of reading arguments from a request or response object. type reqResReader struct { contents *fragmentingReader mex *messageExchange state reqResReaderState messageForFragment messageForFragment initialFragment *readableFragment previousFragment *readableFragment log Logger err error } // arg1Reader returns an ArgReader to read arg1. func (r *reqResReader) arg1Reader() (ArgReader, error) { return r.argReader(false /* last */, reqResReaderPreArg1, reqResReaderPreArg2) } // arg2Reader returns an ArgReader to read arg2. func (r *reqResReader) arg2Reader() (ArgReader, error) { return r.argReader(false /* last */, reqResReaderPreArg2, reqResReaderPreArg3) } // arg3Reader returns an ArgReader to read arg3. func (r *reqResReader) arg3Reader() (ArgReader, error) { return r.argReader(true /* last */, reqResReaderPreArg3, reqResReaderComplete) } // argReader returns an ArgReader that can be used to read an argument. The // ReadCloser must be closed once the argument has been read. 
func (r *reqResReader) argReader(last bool, inState reqResReaderState, outState reqResReaderState) (ArgReader, error) { if r.state != inState { return nil, r.failed(errReqResReaderStateMismatch{state: r.state, expectedState: inState}) } argReader, err := r.contents.ArgReader(last) if err != nil { return nil, r.failed(err) } r.state = outState return argReader, nil } // recvNextFragment receives the next fragment from the underlying message exchange. func (r *reqResReader) recvNextFragment(initial bool) (*readableFragment, error) { if r.initialFragment != nil { fragment := r.initialFragment r.initialFragment = nil r.previousFragment = fragment return fragment, nil } // Wait for the appropriate message from the peer message := r.messageForFragment(initial) frame, err := r.mex.recvPeerFrameOfType(message.messageType()) if err != nil { if err, ok := err.(errorMessage); ok { // If we received a serialized error from the other side, then we should go through // the normal doneReading path so stats get updated with this error. r.err = err.AsSystemError() return nil, err } return nil, r.failed(err) } // Parse the message and setup the fragment fragment, err := parseInboundFragment(r.mex.framePool, frame, message) if err != nil { return nil, r.failed(err) } r.previousFragment = fragment return fragment, nil } // releasePreviousFrament releases the last fragment returned by the reader if // it's still around. This operation is idempotent. 
func (r *reqResReader) releasePreviousFragment() { fragment := r.previousFragment r.previousFragment = nil if fragment != nil { fragment.done() } } // failed indicates the reader failed func (r *reqResReader) failed(err error) error { r.log.Debugf("reader failed: %v existing err: %v", err, r.err) if r.err != nil { return r.err } r.mex.shutdown() r.err = err return r.err } // parseInboundFragment parses an incoming fragment based on the given message func parseInboundFragment(framePool FramePool, frame *Frame, message message) (*readableFragment, error) { rbuf := typed.NewReadBuffer(frame.SizedPayload()) fragment := new(readableFragment) fragment.flags = rbuf.ReadSingleByte() if err := message.read(rbuf); err != nil { return nil, err } fragment.checksumType = ChecksumType(rbuf.ReadSingleByte()) fragment.checksum = rbuf.ReadBytes(fragment.checksumType.ChecksumSize()) fragment.contents = rbuf fragment.onDone = func() { framePool.Release(frame) } return fragment, rbuf.Err() } ================================================ FILE: reqresreaderstate_string.go ================================================ // generated by stringer -type=reqResReaderState; DO NOT EDIT package tchannel import "fmt" const _reqResReaderState_name = "reqResReaderPreArg1reqResReaderPreArg2reqResReaderPreArg3reqResReaderComplete" var _reqResReaderState_index = [...]uint8{0, 19, 38, 57, 77} func (i reqResReaderState) String() string { if i < 0 || i+1 >= reqResReaderState(len(_reqResReaderState_index)) { return fmt.Sprintf("reqResReaderState(%d)", i) } return _reqResReaderState_name[_reqResReaderState_index[i]:_reqResReaderState_index[i+1]] } ================================================ FILE: reqreswriterstate_string.go ================================================ // generated by stringer -type=reqResWriterState; DO NOT EDIT package tchannel import "fmt" const _reqResWriterState_name = "reqResWriterPreArg1reqResWriterPreArg2reqResWriterPreArg3reqResWriterComplete" var 
_reqResWriterState_index = [...]uint8{0, 19, 38, 57, 77} func (i reqResWriterState) String() string { if i < 0 || i+1 >= reqResWriterState(len(_reqResWriterState_index)) { return fmt.Sprintf("reqResWriterState(%d)", i) } return _reqResWriterState_name[_reqResWriterState_index[i]:_reqResWriterState_index[i+1]] } ================================================ FILE: retry.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel import ( "net" "sync" "time" "golang.org/x/net/context" ) // RetryOn represents the types of errors to retry on. type RetryOn int //go:generate stringer -type=RetryOn const ( // RetryDefault is currently the same as RetryConnectionError. RetryDefault RetryOn = iota // RetryConnectionError retries on busy frames, declined frames, and connection errors. RetryConnectionError // RetryNever never retries any errors. 
RetryNever // RetryNonIdempotent will retry errors that occur before a request has been picked up. // E.g. busy frames and declined frames. // This should be used when making calls to non-idempotent endpoints. RetryNonIdempotent // RetryUnexpected will retry busy frames, declined frames, and unenxpected frames. RetryUnexpected // RetryIdempotent will retry all errors that can be retried. This should be used // for idempotent endpoints. RetryIdempotent ) // RequestState is a global request state that persists across retries. type RequestState struct { // Start is the time at which the request was initiated by the caller of RunWithRetry. Start time.Time // SelectedPeers is a set of host:ports that have been selected previously. SelectedPeers map[string]struct{} // Attempt is 1 for the first attempt, and so on. Attempt int retryOpts *RetryOptions } // RetriableFunc is the type of function that can be passed to RunWithRetry. type RetriableFunc func(context.Context, *RequestState) error func isNetError(err error) bool { // TODO(prashantv): Should TChannel internally these to ErrCodeNetwork before returning // them to the user? _, ok := err.(net.Error) return ok } func getErrCode(err error) SystemErrCode { code := GetSystemErrorCode(err) if isNetError(err) { code = ErrCodeNetwork } return code } // CanRetry returns whether an error can be retried for the given retry option. func (r RetryOn) CanRetry(err error) bool { if r == RetryNever { return false } if r == RetryDefault { r = RetryConnectionError } code := getErrCode(err) if code == ErrCodeBusy || code == ErrCodeDeclined { return true } // Never retry bad requests, since it will probably cause another bad request. if code == ErrCodeBadRequest { return false } switch r { case RetryConnectionError: return code == ErrCodeNetwork case RetryUnexpected: return code == ErrCodeUnexpected case RetryIdempotent: return true } return false } // RetryOptions are the retry options used to configure RunWithRetry. 
type RetryOptions struct { // MaxAttempts is the maximum number of calls and retries that will be made. // If this is 0, the default number of attempts (5) is used. MaxAttempts int // RetryOn is the types of errors to retry on. RetryOn RetryOn // TimeoutPerAttempt is the per-retry timeout to use. // If this is zero, then the original timeout is used. TimeoutPerAttempt time.Duration } var defaultRetryOptions = &RetryOptions{ MaxAttempts: 5, } var requestStatePool = sync.Pool{ New: func() interface{} { return &RequestState{} }, } func getRetryOptions(ctx context.Context) *RetryOptions { params := getTChannelParams(ctx) if params == nil { return defaultRetryOptions } opts := params.retryOptions if opts == nil { return defaultRetryOptions } if opts.MaxAttempts == 0 { opts.MaxAttempts = defaultRetryOptions.MaxAttempts } return opts } // HasRetries will return true if there are more retries left. func (rs *RequestState) HasRetries(err error) bool { if rs == nil { return false } rOpts := rs.retryOpts return rs.Attempt < rOpts.MaxAttempts && rOpts.RetryOn.CanRetry(err) } // SinceStart returns the time since the start of the request. If there is no request state, // then the fallback is returned. func (rs *RequestState) SinceStart(now time.Time, fallback time.Duration) time.Duration { if rs == nil { return fallback } return now.Sub(rs.Start) } // PrevSelectedPeers returns the previously selected peers for this request. func (rs *RequestState) PrevSelectedPeers() map[string]struct{} { if rs == nil { return nil } return rs.SelectedPeers } // AddSelectedPeer adds a given peer to the set of selected peers. func (rs *RequestState) AddSelectedPeer(hostPort string) { if rs == nil { return } host := getHost(hostPort) if rs.SelectedPeers == nil { rs.SelectedPeers = map[string]struct{}{ hostPort: {}, host: {}, } } else { rs.SelectedPeers[hostPort] = struct{}{} rs.SelectedPeers[host] = struct{}{} } } // RetryCount returns the retry attempt this is. Essentially, Attempt - 1. 
func (rs *RequestState) RetryCount() int { if rs == nil { return 0 } return rs.Attempt - 1 } // RunWithRetry will take a function that makes the TChannel call, and will // rerun it as specifed in the RetryOptions in the Context. func (ch *Channel) RunWithRetry(runCtx context.Context, f RetriableFunc) error { var err error opts := getRetryOptions(runCtx) rs := ch.getRequestState(opts) defer requestStatePool.Put(rs) for i := 0; i < opts.MaxAttempts; i++ { rs.Attempt++ if opts.TimeoutPerAttempt == 0 { err = f(runCtx, rs) } else { attemptCtx, cancel := context.WithTimeout(runCtx, opts.TimeoutPerAttempt) err = f(attemptCtx, rs) cancel() } if err == nil { return nil } if !opts.RetryOn.CanRetry(err) { if ch.log.Enabled(LogLevelInfo) { ch.log.WithFields(ErrField(err)).Info("Failed after non-retriable error.") } return err } ch.log.WithFields( ErrField(err), LogField{"attempt", rs.Attempt}, LogField{"maxAttempts", opts.MaxAttempts}, ).Info("Retrying request after retryable error.") } // Too many retries, return the last error return err } func (ch *Channel) getRequestState(retryOpts *RetryOptions) *RequestState { rs := requestStatePool.Get().(*RequestState) *rs = RequestState{ Start: ch.timeNow(), retryOpts: retryOpts, } return rs } // getHost returns the host part of a host:port. If no ':' is found, it returns the // original string. Note: This hand-rolled loop is faster than using strings.IndexByte. func getHost(hostPort string) string { for i := 0; i < len(hostPort); i++ { if hostPort[i] == ':' { return hostPort[:i] } } return hostPort } ================================================ FILE: retry_request_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel_test

import (
	"testing"
	"time"

	. "github.com/uber/tchannel-go"

	"github.com/uber/tchannel-go/raw"
	"github.com/uber/tchannel-go/testutils"

	"github.com/stretchr/testify/assert"
	"golang.org/x/net/context"
)

// TestRequestStateRetry runs an "echo" call through RunWithRetry, adding a
// host:port from GetAcceptCloseHostPort as a peer on each of the first four
// attempts and the test server's own host:port on the fifth. It asserts that
// the call succeeds after exactly five invocations and that the RequestState
// has accumulated the expected number of selected peers on each attempt.
func TestRequestStateRetry(t *testing.T) {
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) {
		ts.Register(raw.Wrap(newTestHandler(t)), "echo")

		// Collect four host:ports; their close functions run when this
		// callback returns.
		closedHostPorts := make([]string, 4)
		for i := range closedHostPorts {
			hostPort, close := testutils.GetAcceptCloseHostPort(t)
			defer close()
			closedHostPorts[i] = hostPort
		}

		// Since we close connections remotely, there will be some warnings that we can ignore.
		opts := testutils.NewOpts().DisableLogVerification()
		client := ts.NewClient(opts)
		defer client.Close()

		// counter is the number of completed invocations of the retried function.
		counter := 0
		sc := client.GetSubChannel(ts.Server().ServiceName())
		err := client.RunWithRetry(ctx, func(ctx context.Context, rs *RequestState) error {
			defer func() { counter++ }()

			expectedPeers := counter
			if expectedPeers > 0 {
				// An entry is also added for each host.
				expectedPeers++
			}

			assert.Equal(t, expectedPeers, len(rs.SelectedPeers), "SelectedPeers should not be reused")

			// The first four attempts add one of the collected host:ports;
			// the fifth adds the real server.
			if counter < 4 {
				client.Peers().Add(closedHostPorts[counter])
			} else {
				client.Peers().Add(ts.HostPort())
			}

			_, err := raw.CallV2(ctx, sc, raw.CArgs{
				Method:      "echo",
				CallOptions: &CallOptions{RequestState: rs},
			})
			return err
		})
		assert.NoError(t, err, "RunWithRetry should succeed")
		assert.Equal(t, 5, counter, "RunWithRetry should retry 5 times")
	})
}

================================================
FILE: retry_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel_test

import (
	"net"
	"testing"
	"time"

	. "github.com/uber/tchannel-go"

	"github.com/uber/tchannel-go/testutils"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/context"
)

// createFuncToRetry returns a RetriableFunc that returns errors[i] on its i-th
// call and asserts that the RequestState's Attempt matches the call number. The
// returned pointer exposes the number of completed calls.
func createFuncToRetry(t *testing.T, errors ...error) (RetriableFunc, *int) {
	i := 0
	return func(_ context.Context, rs *RequestState) error {
		defer func() { i++ }()
		if i >= len(errors) {
			t.Fatalf("Retry function has no error to return for this call")
		}
		assert.Equal(t, i+1, rs.Attempt, "Attempt count mismatch")

		err := errors[i]
		return err
	}, &i
}

// testErrors holds one sample error per category exercised by the retry tests;
// all lists every sample.
type testErrors struct {
	Busy       error
	Declined   error
	Timeout    error
	Network    error
	Connection error
	BadRequest error
	Unexpected error
	Cancelled  error

	all []error
}

// getTestErrors builds the sample errors used across the retry tests.
func getTestErrors() testErrors {
	errs := testErrors{
		Busy:       ErrServerBusy,
		Declined:   ErrChannelClosed,
		Timeout:    ErrTimeout,
		Network:    NewSystemError(ErrCodeNetwork, "fake network error"),
		Connection: net.UnknownNetworkError("fake connection error"),
		BadRequest: ErrTimeoutRequired,
		Unexpected: NewSystemError(ErrCodeUnexpected, "fake unexpected error"),
		Cancelled:  NewSystemError(ErrCodeCancelled, "fake cancelled error"),
	}
	errs.all = []error{errs.Busy, errs.Declined, errs.Timeout, errs.Network, errs.Connection, errs.BadRequest, errs.Unexpected, errs.Cancelled}
	return errs
}

// TestCanRetry checks, for every RetryOn value, that CanRetry returns true
// exactly for the errors listed in RetryOK and false for every other sample.
func TestCanRetry(t *testing.T) {
	e := getTestErrors()
	tests := []struct {
		RetryOn RetryOn
		RetryOK []error
	}{
		{RetryNever, nil},
		{RetryDefault, []error{e.Busy, e.Declined, e.Network, e.Connection}},
		{RetryConnectionError, []error{e.Busy, e.Declined, e.Network, e.Connection}},
		{RetryNonIdempotent, []error{e.Busy, e.Declined}},
		{RetryUnexpected, []error{e.Busy, e.Declined, e.Unexpected}},
		{RetryIdempotent, []error{e.Busy, e.Declined, e.Timeout, e.Network, e.Connection, e.Unexpected, e.Cancelled}},
	}

	for _, tt := range tests {
		// Index the retryable errors so membership can be looked up per sample.
		retryOK := make(map[error]bool)
		for _, err := range tt.RetryOK {
			retryOK[err] = true
		}

		for _, err := range e.all {
			expectOK := retryOK[err]
			assert.Equal(t, expectOK, tt.RetryOn.CanRetry(err),
				"%v.CanRetry(%v) expected %v", tt.RetryOn, err, expectOK)
		}
	}
}

// TestNoRetry verifies that with RetryNever every sample error is returned
// after a single call to the retried function.
func TestNoRetry(t *testing.T) {
	ch := testutils.NewClient(t, nil)
	defer ch.Close()

	e := getTestErrors()
	retryOpts := &RetryOptions{RetryOn: RetryNever}
	for _, fErr := range e.all {
		ctx, cancel := NewContextBuilder(time.Second).SetRetryOptions(retryOpts).Build()
		defer cancel()

		f, counter := createFuncToRetry(t, fErr)
		err := ch.RunWithRetry(ctx, f)
		assert.Equal(t, fErr, err)
		assert.Equal(t, 1, *counter, "f should not be retried when retried are disabled")
	}
}

// TestRetryTillMaxAttempts runs RunWithRetry for several combinations of
// MaxAttempts and number of leading ErrServerBusy failures, checking both the
// final error and how many times the function was called.
func TestRetryTillMaxAttempts(t *testing.T) {
	ch := testutils.NewClient(t, nil)
	defer ch.Close()

	setErr := ErrServerBusy
	runTest := func(maxAttempts, numErrors, expectCounter int, expectErr error) {
		retryOpts := &RetryOptions{MaxAttempts: maxAttempts}
		ctx, cancel := NewContextBuilder(time.Second).SetRetryOptions(retryOpts).Build()
		defer cancel()

		// numErrors failures followed by a single success.
		var errors []error
		for i := 0; i < numErrors; i++ {
			errors = append(errors, setErr)
		}
		errors = append(errors, nil)

		f, counter := createFuncToRetry(t, errors...)
		err := ch.RunWithRetry(ctx, f)
		assert.Equal(t, expectErr, err,
			"unexpected result for maxAttempts = %v numErrors = %v", maxAttempts, numErrors)
		assert.Equal(t, expectCounter, *counter,
			"expected f to be retried %v times with maxAttempts = %v numErrors = %v",
			expectCounter, maxAttempts, numErrors)
	}

	for numAttempts := 1; numAttempts < 5; numAttempts++ {
		for numErrors := 0; numErrors < numAttempts+3; numErrors++ {
			// The call fails overall once the failures use up every attempt.
			var expectErr error
			if numErrors >= numAttempts {
				expectErr = setErr
			}

			// The function is called once per failure plus the success, capped
			// at the number of attempts.
			expectCount := numErrors + 1
			if expectCount > numAttempts {
				expectCount = numAttempts
			}

			runTest(numAttempts, numErrors, expectCount, expectErr)
		}
	}
}

// TestRetrySubContextNoTimeoutPerAttempt verifies that without a per-attempt
// timeout the retried function receives the original context, and that it is
// called five times when it always returns a busy error.
func TestRetrySubContextNoTimeoutPerAttempt(t *testing.T) {
	e := getTestErrors()
	ctx, cancel := NewContext(time.Second)
	defer cancel()

	ch := testutils.NewClient(t, nil)
	defer ch.Close()

	counter := 0
	ch.RunWithRetry(ctx, func(sctx context.Context, _ *RequestState) error {
		counter++
		assert.Equal(t, ctx, sctx, "Sub-context should be the same")
		return e.Busy
	})
	assert.Equal(t, 5, counter, "RunWithRetry did not run f enough times")
}

// TestRetrySubContextTimeoutPerAttempt verifies that with a per-attempt timeout
// each call gets a different context whose deadline is later than the previous
// attempt's and earlier than the overall deadline.
func TestRetrySubContextTimeoutPerAttempt(t *testing.T) {
	e := getTestErrors()
	ctx, cancel := NewContextBuilder(time.Second).
		SetTimeoutPerAttempt(time.Millisecond).Build()
	defer cancel()

	ch := testutils.NewClient(t, nil)
	defer ch.Close()

	var lastDeadline time.Time

	counter := 0
	ch.RunWithRetry(ctx, func(sctx context.Context, _ *RequestState) error {
		counter++

		assert.NotEqual(t, ctx, sctx, "Sub-context should be different")
		deadline, _ := sctx.Deadline()
		assert.True(t, deadline.After(lastDeadline), "Deadline is invalid")
		lastDeadline = deadline

		overallDeadline, _ := ctx.Deadline()
		assert.True(t, overallDeadline.After(deadline), "Deadline is invalid")

		return e.Busy
	})
	assert.Equal(t, 5, counter, "RunWithRetry did not run f enough times")
}

// TestRetryNetConnect verifies that errors from net.Dial are retried: the
// function dials the address from GetClosedHostPort until HasRetries reports
// none are left, then dials a live listener. It expects success after five calls.
func TestRetryNetConnect(t *testing.T) {
	e := getTestErrors()
	ch := testutils.NewClient(t, nil)
	defer ch.Close()

	ctx, cancel := NewContext(time.Second)
	defer cancel()

	closedAddr := testutils.GetClosedHostPort(t)
	listenC, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err, "Listen failed")
	defer listenC.Close()

	counter := 0
	f := func(ctx context.Context, rs *RequestState) error {
		counter++
		// On the attempt after which no retries remain, connect to the live listener.
		if !rs.HasRetries(e.Connection) {
			c, err := net.Dial("tcp", listenC.Addr().String())
			if err == nil {
				c.Close()
			}
			return err
		}

		_, err := net.Dial("tcp", closedAddr)
		return err
	}

	assert.NoError(t, ch.RunWithRetry(ctx, f), "RunWithRetry should succeed")
	assert.Equal(t, 5, counter, "RunWithRetry should have run f 5 times")
}

// TestRequestStateSince verifies that SinceStart returns the fallback for a nil
// RequestState and the elapsed time since Start otherwise.
func TestRequestStateSince(t *testing.T) {
	baseTime := time.Date(2015, 1, 2, 3, 4, 5, 6, time.UTC)
	tests := []struct {
		requestState *RequestState
		now          time.Time
		fallback     time.Duration
		expected     time.Duration
	}{
		{
			requestState: nil,
			fallback:     3 * time.Millisecond,
			expected:     3 * time.Millisecond,
		},
		{
			requestState: &RequestState{Start: baseTime},
			now:          baseTime.Add(7 * time.Millisecond),
			fallback:     5 * time.Millisecond,
			expected:     7 * time.Millisecond,
		},
	}

	for _, tt := range tests {
		got := tt.requestState.SinceStart(tt.now, tt.fallback)
		assert.Equal(t, tt.expected, got, "%+v.SinceStart(%v, %v) expected %v got %v",
			tt.requestState, tt.now, tt.fallback, tt.expected, got)
	}
}

================================================
FILE: retryon_string.go
================================================
// generated by stringer -type=RetryOn; DO NOT EDIT

package tchannel

import "fmt"

const _RetryOn_name = "RetryDefaultRetryConnectionErrorRetryNeverRetryNonIdempotentRetryUnexpectedRetryIdempotent"

var _RetryOn_index = [...]uint8{0, 12, 32, 42, 60, 75, 90}

func (i RetryOn) String() string {
	if i < 0 || i+1 >= RetryOn(len(_RetryOn_index)) {
		return fmt.Sprintf("RetryOn(%d)", i)
	}
	return _RetryOn_name[_RetryOn_index[i]:_RetryOn_index[i+1]]
}

================================================
FILE: root_peer_list.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import "sync"

// RootPeerList is the root peer list which is only used to connect to
// peers and share peers between subchannels.
type RootPeerList struct { sync.RWMutex channel Connectable onPeerStatusChanged func(*Peer) peersByHostPort map[string]*Peer } func newRootPeerList(ch Connectable, onPeerStatusChanged func(*Peer)) *RootPeerList { return &RootPeerList{ channel: ch, onPeerStatusChanged: onPeerStatusChanged, peersByHostPort: make(map[string]*Peer), } } // newChild returns a new isolated peer list that shares the underlying peers // with the root peer list. func (l *RootPeerList) newChild() *PeerList { return newPeerList(l) } // Add adds a peer to the root peer list if it does not exist, or return // an existing peer if it exists. func (l *RootPeerList) Add(hostPort string) *Peer { l.RLock() if p, ok := l.peersByHostPort[hostPort]; ok { l.RUnlock() return p } l.RUnlock() l.Lock() defer l.Unlock() if p, ok := l.peersByHostPort[hostPort]; ok { return p } var p *Peer // To avoid duplicate connections, only the root list should create new // peers. All other lists should keep refs to the root list's peers. p = newPeer(l.channel, hostPort, l.onPeerStatusChanged, l.onClosedConnRemoved) l.peersByHostPort[hostPort] = p return p } // GetOrAdd returns a peer for the given hostPort, creating one if it doesn't yet exist. func (l *RootPeerList) GetOrAdd(hostPort string) *Peer { peer, ok := l.Get(hostPort) if ok { return peer } return l.Add(hostPort) } // Get returns a peer for the given hostPort if it exists. func (l *RootPeerList) Get(hostPort string) (*Peer, bool) { l.RLock() p, ok := l.peersByHostPort[hostPort] l.RUnlock() return p, ok } func (l *RootPeerList) onClosedConnRemoved(peer *Peer) { hostPort := peer.HostPort() p, ok := l.Get(hostPort) if !ok { // It's possible that multiple connections were closed and removed at the same time, // so multiple goroutines might be removing the peer from the root peer list. 
return } if p.canRemove() { l.Lock() delete(l.peersByHostPort, hostPort) l.Unlock() l.channel.Logger().WithFields( LogField{"remoteHostPort", hostPort}, ).Debug("Removed peer from root peer list.") } } // Copy returns a map of the peer list. This method should only be used for testing. func (l *RootPeerList) Copy() map[string]*Peer { l.RLock() defer l.RUnlock() listCopy := make(map[string]*Peer) for k, v := range l.peersByHostPort { listCopy[k] = v } return listCopy } ================================================ FILE: scripts/install-thrift.sh ================================================ #!/bin/bash set -euo pipefail if [ -z "${1}" ]; then echo "usage: ${0} installDirPath" >&2 exit 1 fi BIN_FILE="thrift-1" TAR_FILE="${BIN_FILE}-$(uname -s | tr '[:upper:]' '[:lower:]')-$(uname -m).tar.gz" TAR_LOCATION="https://github.com/uber/tchannel-go/releases/download/thrift-v1.0.0-dev/${TAR_FILE}" mkdir -p "${1}" cd "${1}" wget "${TAR_LOCATION}" tar xzf "${TAR_FILE}" rm -f "${TAR_FILE}" mv "${BIN_FILE}" "thrift" ================================================ FILE: scripts/vbumper/main.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
// Command-line flags selecting which files are rewritten and which version
// and date are stamped into them.
var (
	_changelogFile = flag.String("changelog-file", "CHANGELOG.md", "Filename of the changelog file")
	_versionFile   = flag.String("version-file", "version.go", "Filename of where the version information is stored")
	_version       = flag.String("version", "", "Version to mention in changelog and version.go")
	_versionDate   = flag.String("version-date", "", "Date to use in the changelog, by default the current date")
	_skipChangelog = flag.Bool("skip-changelog", false, "Skip updating the changelog")
)

// main updates the changelog and then the version file for the version given
// by --version, exiting with a non-zero status on the first failure.
func main() {
	// Seed the date before parsing so that an explicit --version-date wins.
	*_versionDate = time.Now().Format("2006-01-02")
	flag.Parse()

	if *_version == "" {
		log.Fatal("Please specify the version to release using --version")
	}
	*_version = strings.TrimPrefix(*_version, "v")

	// log.Fatal joins its operands like fmt.Print, which adds no space next to
	// a string operand; use Fatalf so the message and the error are separated.
	prevVersion, err := updateChangelog()
	if err != nil {
		log.Fatalf("failed to update changelog: %v", err)
	}

	if err := updateVersion(prevVersion); err != nil {
		log.Fatalf("failed to update version: %v", err)
	}
}

// updateVersion rewrites the version file, replacing prevVersion with the
// version requested on the command line. It returns an error, without
// touching the file, when prevVersion is empty or not present in the file.
func updateVersion(prevVersion string) error {
	versionBytes, err := ioutil.ReadFile(*_versionFile)
	if err != nil {
		return err
	}

	contents := string(versionBytes)
	if prevVersion == "" || !strings.Contains(contents, prevVersion) {
		return fmt.Errorf("previous version %q not found in %v", prevVersion, *_versionFile)
	}

	newContents := insertNewVersion(contents, prevVersion, *_version)
	return ioutil.WriteFile(*_versionFile, []byte(newContents), 0666)
}

// insertNewVersion replaces the first occurrence of prevVersion in contents,
// together with any suffix up to the next double quote, with newVersion.
//
// When prevVersion does not occur, contents is returned unchanged. When no
// double quote follows the occurrence, only the prevVersion text itself is
// replaced.
func insertNewVersion(contents, prevVersion, newVersion string) string {
	versionStart := strings.Index(contents, prevVersion)
	if versionStart < 0 {
		return contents
	}

	// Default to replacing exactly prevVersion; extend to the closing quote
	// (covering suffixes such as "-dev") when there is one.
	versionEnd := versionStart + len(prevVersion)
	if quote := strings.Index(contents[versionStart:], `"`); quote >= 0 {
		versionEnd = versionStart + quote
	}
	return contents[:versionStart] + newVersion + contents[versionEnd:]
}

// updateChangelog inserts a section and a compare link for the new version
// into the changelog and returns the previously released version. The file is
// left untouched when --skip-changelog is set, but the previous version is
// still determined from it.
func updateChangelog() (oldVersion string, _ error) {
	changelogBytes, err := ioutil.ReadFile(*_changelogFile)
	if err != nil {
		return "", err
	}

	newLog, oldVersion, err := insertNewChangelog(string(changelogBytes))
	if err != nil {
		return "", err
	}

	newLog, err = insertChangesLink(newLog, oldVersion, *_version)
	if err != nil {
		return "", err
	}

	if *_skipChangelog {
		return oldVersion, nil
	}

	return oldVersion, ioutil.WriteFile(*_changelogFile, []byte(newLog), 0666)
}

// insertNewChangelog finds the most recent "## [version]" header in contents
// and inserts a freshly generated section for the new version right before
// it. It returns the new contents and the previous version taken from that
// header, or an error when the header is missing or malformed.
func insertNewChangelog(contents string) (string, string, error) {
	prevVersionHeader := strings.Index(contents, "\n## [")
	if prevVersionHeader < 0 {
		return "", "", errors.New("failed to find version header in changelog")
	}

	// Skip the newline
	prevVersionHeader++
	versionLine := contents[prevVersionHeader:]
	// Only look at the header line itself so that a "]" further down the
	// file is never mistaken for the end of the version.
	if eol := strings.IndexByte(versionLine, '\n'); eol >= 0 {
		versionLine = versionLine[:eol]
	}
	prevVersionEnd := strings.Index(versionLine, "]")
	if prevVersionEnd < 0 {
		return "", "", errors.New("failed to find end of version header in changelog")
	}
	// The header starts with "## [", so the version begins at offset 4.
	prevVersion := strings.TrimSpace(versionLine[4:prevVersionEnd])
	if prevVersion == "" {
		return "", "", errors.New("failed to find previous version in changelog header")
	}

	// The version tag has a "v" prefix.
	newChanges, err := getNewChangelog("v" + prevVersion)
	if err != nil {
		return "", "", err
	}

	newContents := contents[:prevVersionHeader] + newChanges + contents[prevVersionHeader:]
	return newContents, prevVersion, nil
}

// getNewChangelog renders the changelog section for the new version, listing
// the commit subjects since the prevVersion tag, or a placeholder entry when
// there are none.
func getNewChangelog(prevVersion string) (string, error) {
	changes, err := getChanges(prevVersion)
	if err != nil {
		return "", err
	}

	if len(changes) == 0 {
		changes = []string{"No changes yet"}
	}

	buf := &bytes.Buffer{}
	err = _changeTmpl.Execute(buf, struct {
		Version string
		Date    string
		Changes []string
	}{
		Version: *_version,
		Date:    *_versionDate,
		Changes: changes,
	})
	if err != nil {
		return "", err
	}
	return buf.String(), nil
}

// _changeTmpl is the template for a single changelog section.
//
// NOTE(review): this file imports html/template, which HTML-escapes the
// interpolated values, so commit subjects containing characters such as
// '&', '<' or quotes end up escaped in the markdown changelog. Switching the
// file-level import to text/template would avoid that.
var _changeTmpl = template.Must(template.New("changelog").Parse(
	`## [{{ .Version }}] - {{ .Date }}
### Changed
{{ range .Changes }}
* {{ . -}}
{{ end }}

`))

// getChanges returns the non-empty, trimmed subject lines of all non-merge
// commits between the prevVersion tag and HEAD, as reported by git log.
func getChanges(prevVersion string) ([]string, error) {
	cmd := exec.Command("git", "log", "--format=%s", "--no-merges", prevVersion+"..HEAD")
	cmd.Stderr = os.Stderr
	out, err := cmd.Output()
	if err != nil {
		return nil, err
	}

	lines := strings.Split(string(out), "\n")
	newLines := make([]string, 0, len(lines))
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		newLines = append(newLines, line)
	}
	return newLines, nil
}

// insertChangesLink adds a "[version]: compare-url" line directly after the
// line containing the "(Version Links)" marker. It returns an error when the
// marker, or the newline terminating its line, cannot be found.
func insertChangesLink(contents, prevVersion, version string) (string, error) {
	linksMarker := strings.Index(contents, "(Version Links)")
	if linksMarker == -1 {
		return "", errors.New("failed to find marker for version links section")
	}

	newLine := strings.IndexByte(contents[linksMarker:], '\n')
	if newLine < 0 {
		return "", errors.New("failed to find newline after version links section")
	}

	insertAt := linksMarker + newLine + 1

	linkBlock := fmt.Sprintf("[%v]: %v\n", version, getChangesLink(prevVersion, version))
	newContents := contents[:insertAt] + linkBlock + contents[insertAt:]
	return newContents, nil
}

// getChangesLink returns the GitHub compare URL between the tags of the two
// given versions (both without the "v" prefix).
func getChangesLink(prevVersion, curVersion string) string {
	// Example link:
	// https://github.com/uber/tchannel-go/compare/v1.8.0...v1.8.1
	return fmt.Sprintf("https://github.com/uber/tchannel-go/compare/v%v...v%v", prevVersion, curVersion)
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. //go:build aix || dragonfly || freebsd || netbsd || openbsd || solaris // +build aix dragonfly freebsd netbsd openbsd solaris package tchannel import "golang.org/x/sys/unix" func getSendQueueLen(fd uintptr) (int, error) { return unix.IoctlGetInt(int(fd), unix.TIOCOUTQ) } ================================================ FILE: sockio_darwin.go ================================================ // Copyright (c) 2020 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. //go:build darwin // +build darwin package tchannel import "golang.org/x/sys/unix" func getSendQueueLen(fd uintptr) (int, error) { // https://www.unix.com/man-page/osx/2/getsockopt/ return unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_NWRITE) } ================================================ FILE: sockio_linux.go ================================================ // Copyright (c) 2020 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. //go:build linux // +build linux package tchannel import "golang.org/x/sys/unix" func getSendQueueLen(fd uintptr) (int, error) { // https://linux.die.net/man/7/tcp return unix.IoctlGetInt(int(fd), unix.SIOCOUTQ) } ================================================ FILE: sockio_non_unix.go ================================================ // Copyright (c) 2020 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // Opposite of sockio_unix.go //go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !solaris // +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris package tchannel func (c *Connection) sendBufSize() (sendBufUsage int, sendBufSize int, _ error) { return -1, -1, errNoSyscallConn } ================================================ FILE: sockio_unix.go ================================================ // Copyright (c) 2020 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
// Match the golang/sys unix file, https://github.com/golang/sys/blob/master/unix/syscall_unix.go#L5 //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris package tchannel import ( "go.uber.org/multierr" "golang.org/x/sys/unix" ) func (c *Connection) sendBufSize() (sendBufUsage int, sendBufSize int, _ error) { sendBufSize = -1 sendBufUsage = -1 if c.sysConn == nil { return sendBufUsage, sendBufSize, errNoSyscallConn } var sendBufLenErr, sendBufLimErr error errs := c.sysConn.Control(func(fd uintptr) { sendBufUsage, sendBufLenErr = getSendQueueLen(fd) sendBufSize, sendBufLimErr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) }) errs = multierr.Append(errs, sendBufLimErr) errs = multierr.Append(errs, sendBufLenErr) return sendBufUsage, sendBufSize, errs } ================================================ FILE: stats/metrickey.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
// DefaultMetricPrefix is the default mapping for metrics to statsd keys.
// It is MetricWithPrefix with a fixed "tchannel." prefix.
func DefaultMetricPrefix(name string, tags map[string]string) string {
	const defaultPrefix = "tchannel."
	return MetricWithPrefix(defaultPrefix, name, tags)
}

// bufPool recycles the scratch buffers used to assemble metric keys.
var bufPool = sync.Pool{
	New: func() interface{} { return new(bytes.Buffer) },
}

// MetricWithPrefix builds the statsd key for a metric: prefix, then name,
// then one dot-separated segment per tag that applies to the metric.
//
// Names starting with "outbound" append service, target-service and
// target-endpoint (plus retry-count for "outbound.calls.retries"); names
// starting with "inbound" append calling-service, service and endpoint. A tag
// missing from tags is written as "no-<tag>"; present values are cleaned with
// writeClean.
func MetricWithPrefix(prefix, name string, tags map[string]string) string {
	out := bufPool.Get().(*bytes.Buffer)
	// The returned string is built before the deferred Put runs.
	defer bufPool.Put(out)
	out.Reset()

	out.WriteString(prefix)
	out.WriteString(name)

	var tagKeys []string
	if strings.HasPrefix(name, "outbound") {
		tagKeys = []string{"service", "target-service", "target-endpoint"}
		if strings.HasPrefix(name, "outbound.calls.retries") {
			tagKeys = append(tagKeys, "retry-count")
		}
	} else if strings.HasPrefix(name, "inbound") {
		tagKeys = []string{"calling-service", "service", "endpoint"}
	}

	for _, key := range tagKeys {
		out.WriteByte('.')
		if val, found := tags[key]; found {
			writeClean(out, val)
			continue
		}
		out.WriteString("no-")
		out.WriteString(key)
	}

	return out.String()
}

// writeClean writes v to buf byte by byte, substituting '-' for each of the
// special characters '{', '}', '/', '\\', ':', '.', space, tab, carriage
// return and newline.
func writeClean(buf *bytes.Buffer, v string) {
	const special = "{}/\\:. \t\r\n"
	for i := 0; i < len(v); i++ {
		if b := v[i]; strings.IndexByte(special, b) >= 0 {
			buf.WriteByte('-')
		} else {
			buf.WriteByte(b)
		}
	}
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package stats import ( "bytes" "testing" "github.com/stretchr/testify/assert" ) func TestDefaultMetricPrefix(t *testing.T) { outboundTags := map[string]string{ "service": "callerS", "target-service": "targetS", "target-endpoint": "targetE", "retry-count": "retryN", } inboundTags := map[string]string{ "service": "targetS", "endpoint": "targetE", "calling-service": "callerS", } tests := []struct { name string tags map[string]string expected string }{ { name: "outbound.calls.sent", tags: outboundTags, expected: "tchannel.outbound.calls.sent.callerS.targetS.targetE", }, { name: "outbound.calls.retries", tags: outboundTags, expected: "tchannel.outbound.calls.retries.callerS.targetS.targetE.retryN", }, { name: "inbound.calls.recvd", tags: inboundTags, expected: "tchannel.inbound.calls.recvd.callerS.targetS.targetE", }, { name: "inbound.calls.recvd", tags: nil, expected: "tchannel.inbound.calls.recvd.no-calling-service.no-service.no-endpoint", }, } for _, tt := range tests { assert.Equal(t, tt.expected, DefaultMetricPrefix(tt.name, tt.tags), "DefaultMetricPrefix(%q, %v) failed", tt.name, tt.tags) } } func TestClean(t *testing.T) { tests := []struct { key string expected string }{ {"metric", "metric"}, {"met:ric", "met-ric"}, {"met{}ric", "met--ric"}, {"\\metric", "-metric"}, {"/metric", "-metric"}, {" met.ric ", "--met-ric--"}, } for _, tt := range tests { buf := &bytes.Buffer{} writeClean(buf, tt.key) assert.Equal(t, tt.expected, buf.String(), "clean(%q) failed", tt.key) } } func BenchmarkMetricPrefix(b *testing.B) { outboundTags := map[string]string{ "service": "callerS", "target-service": "targetS", "target-endpoint": "targetE", "retry-count": "retryN", } for i := 0; i < b.N; i++ { MetricWithPrefix("", "outbound.calls.retries", outboundTags) DefaultMetricPrefix("outbound.calls.retries", outboundTags) } } ================================================ FILE: stats/statsdreporter.go ================================================ // Copyright (c) 2015 Uber 
Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package stats import ( "time" "github.com/cactus/go-statsd-client/statsd" "github.com/uber/tchannel-go" ) const samplingRate = 1.0 // MetricKey is called to generate the statsd key for a given metric and tags. var MetricKey = DefaultMetricPrefix type statsdReporter struct { client statsd.Statter } // NewStatsdReporter returns a StatsReporter that reports to statsd on the given addr. func NewStatsdReporter(addr, prefix string) (tchannel.StatsReporter, error) { client, err := statsd.NewBufferedClient(addr, prefix, time.Second, 0) if err != nil { return nil, err } return NewStatsdReporterClient(client), nil } // NewStatsdReporterClient returns a StatsReporter that reports stats to the given client. 
func NewStatsdReporterClient(client statsd.Statter) tchannel.StatsReporter { return &statsdReporter{client} } func (r *statsdReporter) IncCounter(name string, tags map[string]string, value int64) { // TODO(prashant): Deal with errors in the client. r.client.Inc(MetricKey(name, tags), value, samplingRate) } func (r *statsdReporter) UpdateGauge(name string, tags map[string]string, value int64) { r.client.Gauge(MetricKey(name, tags), value, samplingRate) } func (r *statsdReporter) RecordTimer(name string, tags map[string]string, d time.Duration) { r.client.TimingDuration(MetricKey(name, tags), d, samplingRate) } ================================================ FILE: stats/tally.go ================================================ package stats import ( "sync" "time" "github.com/uber/tchannel-go" "github.com/uber-go/tally" ) type wrapper struct { sync.RWMutex scope tally.Scope byTags map[knownTags]*taggedScope } type knownTags struct { dest string source string procedure string retryCount string } type taggedScope struct { sync.RWMutex scope tally.Scope // already tagged with some set of tags counters map[string]tally.Counter gauges map[string]tally.Gauge timers map[string]tally.Timer } // NewTallyReporter takes a tally.Scope and wraps it so it ca be used as a // StatsReporter. The list of metrics emitted is documented on: // https://tchannel.readthedocs.io/en/latest/metrics/ // The metrics emitted are similar to YARPC, the tags emitted are: // source, dest, procedure, and retry-count. 
func NewTallyReporter(scope tally.Scope) tchannel.StatsReporter { return &wrapper{ scope: scope, byTags: make(map[knownTags]*taggedScope), } } func (w *wrapper) IncCounter(name string, tags map[string]string, value int64) { ts := w.getTaggedScope(tags) ts.getCounter(name).Inc(value) } func (w *wrapper) UpdateGauge(name string, tags map[string]string, value int64) { ts := w.getTaggedScope(tags) ts.getGauge(name).Update(float64(value)) } func (w *wrapper) RecordTimer(name string, tags map[string]string, d time.Duration) { ts := w.getTaggedScope(tags) ts.getTimer(name).Record(d) } func (w *wrapper) getTaggedScope(tags map[string]string) *taggedScope { kt := convertTags(tags) w.RLock() ts, ok := w.byTags[kt] w.RUnlock() if ok { return ts } w.Lock() defer w.Unlock() // Always double-check under the write-lock. if ts, ok := w.byTags[kt]; ok { return ts } ts = &taggedScope{ scope: w.scope.Tagged(kt.tallyTags()), counters: make(map[string]tally.Counter), gauges: make(map[string]tally.Gauge), timers: make(map[string]tally.Timer), } w.byTags[kt] = ts return ts } func convertTags(tags map[string]string) knownTags { if ts, ok := tags["target-service"]; ok { // Outbound call. return knownTags{ dest: ts, source: tags["service"], procedure: tags["target-endpoint"], retryCount: tags["retry-count"], } } if cs, ok := tags["calling-service"]; ok { // Inbound call. return knownTags{ dest: tags["service"], source: cs, procedure: tags["endpoint"], retryCount: tags["retry-count"], } } // TChannel doesn't use any other tags, so ignore all others for now. return knownTags{} } // Create a sub-scope for this set of known tags. 
func (kt knownTags) tallyTags() map[string]string { tallyTags := make(map[string]string, 5) if kt.dest != "" { tallyTags["dest"] = kt.dest } if kt.source != "" { tallyTags["source"] = kt.source } if kt.procedure != "" { tallyTags["procedure"] = kt.procedure } if kt.retryCount != "" { tallyTags["retry-count"] = kt.retryCount } return tallyTags } func (ts *taggedScope) getCounter(name string) tally.Counter { ts.RLock() counter, ok := ts.counters[name] ts.RUnlock() if ok { return counter } ts.Lock() defer ts.Unlock() // No double-check under the lock, as overwriting the counter has // no impact. counter = ts.scope.Counter(name) ts.counters[name] = counter return counter } func (ts *taggedScope) getGauge(name string) tally.Gauge { ts.RLock() gauge, ok := ts.gauges[name] ts.RUnlock() if ok { return gauge } ts.Lock() defer ts.Unlock() // No double-check under the lock, as overwriting the counter has // no impact. gauge = ts.scope.Gauge(name) ts.gauges[name] = gauge return gauge } func (ts *taggedScope) getTimer(name string) tally.Timer { ts.RLock() timer, ok := ts.timers[name] ts.RUnlock() if ok { return timer } ts.Lock() defer ts.Unlock() // No double-check under the lock, as overwriting the counter has // no impact. timer = ts.scope.Timer(name) ts.timers[name] = timer return timer } ================================================ FILE: stats/tally_test.go ================================================ package stats import ( "fmt" "testing" "time" "github.com/stretchr/testify/assert" "github.com/uber-go/tally" "github.com/uber/tchannel-go/testutils" ) func TestConvertTags(t *testing.T) { tests := []struct { tags map[string]string want map[string]string }{ { tags: nil, want: map[string]string{}, }, { // unknown tags are ignored. 
tags: map[string]string{"foo": "bar"}, want: map[string]string{}, }, { // Outbound call tags: map[string]string{ "target-service": "tsvc", "service": "foo", "target-endpoint": "te", "retry-count": "4", // ignored tag "foo": "bar", }, want: map[string]string{ "dest": "tsvc", "source": "foo", "procedure": "te", "retry-count": "4", }, }, { // Inbound call tags: map[string]string{ "service": "foo", "calling-service": "bar", "endpoint": "ep", }, want: map[string]string{ "dest": "foo", "source": "bar", "procedure": "ep", }, }, } for _, tt := range tests { t.Run(fmt.Sprint(tt.tags), func(t *testing.T) { got := convertTags(tt.tags) assert.Equal(t, tt.want, got.tallyTags()) }) } } func TestNewTallyReporter(t *testing.T) { want := tally.NewTestScope("" /* prefix */, nil /* tags */) scope := tally.NewTestScope("" /* prefix */, nil /* tags */) wrapped := NewTallyReporter(scope) for i := 0; i < 10; i++ { wrapped.IncCounter("outbound.calls", map[string]string{ "target-service": "tsvc", "service": "foo", "target-endpoint": "te", }, 3) want.Tagged(map[string]string{ "dest": "tsvc", "source": "foo", "procedure": "te", }).Counter("outbound.calls").Inc(3) wrapped.UpdateGauge("num-connections", map[string]string{ "service": "foo", }, 3) want.Gauge("num-connections").Update(3) wrapped.RecordTimer("inbound.call.latency", map[string]string{ "service": "foo", "calling-service": "bar", "endpoint": "ep", }, time.Second) want.Tagged(map[string]string{ "dest": "foo", "source": "bar", "procedure": "ep", }).Timer("inbound.call.latency").Record(time.Second) } assert.Equal(t, want.Snapshot(), scope.Snapshot()) } func TestTallyIntegration(t *testing.T) { clientScope := tally.NewTestScope("" /* prefix */, nil /* tags */) serverScope := tally.NewTestScope("" /* prefix */, nil /* tags */) // Verify the tagged metrics from that call. 
tests := []struct { msg string scope tally.TestScope counters []string timers []string }{ { msg: "client metrics", scope: clientScope, counters: []string{ "outbound.calls.send+dest=testService,procedure=echo,source=testService-client", "outbound.calls.success+dest=testService,procedure=echo,source=testService-client", }, timers: []string{ "outbound.calls.per-attempt.latency+dest=testService,procedure=echo,source=testService-client", "outbound.calls.latency+dest=testService,procedure=echo,source=testService-client", }, }, { msg: "server metrics", scope: serverScope, counters: []string{ "inbound.calls.recvd+dest=testService,procedure=echo,source=testService-client", "inbound.calls.success+dest=testService,procedure=echo,source=testService-client", }, timers: []string{ "inbound.calls.latency+dest=testService,procedure=echo,source=testService-client", }, }, } // Use a closure so that the server/client are closed before we verify metrics. // Otherwise, we may attempt to verify metrics before they've been flushed by TChannel. 
func() { server := testutils.NewServer(t, testutils.NewOpts().SetStatsReporter(NewTallyReporter(serverScope))) defer server.Close() testutils.RegisterEcho(server, nil) client := testutils.NewClient(t, testutils.NewOpts().SetStatsReporter(NewTallyReporter(clientScope))) defer client.Close() testutils.AssertEcho(t, client, server.PeerInfo().HostPort, server.ServiceName()) }() for _, tt := range tests { snapshot := tt.scope.Snapshot() for _, counter := range tt.counters { assert.Contains(t, snapshot.Counters(), counter, "missing counter") } for _, timer := range tt.timers { assert.Contains(t, snapshot.Timers(), timer, "missing timer") } } } func BenchmarkTallyCounter(b *testing.B) { scope := tally.NewTestScope("" /* prefix */, nil /* tags */) wrapped := NewTallyReporter(scope) tags := map[string]string{ "target-service": "tsvc", "service": "foo", "target-endpoint": "te", } for i := 0; i < b.N; i++ { wrapped.IncCounter("outbound.calls", tags, 1) } } ================================================ FILE: stats.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
// StatsReporter is the interface used to report stats.
type StatsReporter interface {
	// IncCounter adds value to the counter identified by name and tags.
	IncCounter(name string, tags map[string]string, value int64)
	// UpdateGauge sets the gauge identified by name and tags to value.
	UpdateGauge(name string, tags map[string]string, value int64)
	// RecordTimer records the duration d for the timer identified by name and tags.
	RecordTimer(name string, tags map[string]string, d time.Duration)
}

// NullStatsReporter is a stats reporter that discards the statistics.
var NullStatsReporter StatsReporter = nullStatsReporter{}

// nullStatsReporter implements StatsReporter with empty methods.
type nullStatsReporter struct{}

func (nullStatsReporter) IncCounter(_ string, _ map[string]string, _ int64)       {}
func (nullStatsReporter) UpdateGauge(_ string, _ map[string]string, _ int64)      {}
func (nullStatsReporter) RecordTimer(_ string, _ map[string]string, _ time.Duration) {}

// SimpleStatsReporter is a stats reporter that reports stats to the log.
var SimpleStatsReporter StatsReporter = simpleStatsReporter{}

// simpleStatsReporter implements StatsReporter by printing every reported
// stat through the standard log package.
type simpleStatsReporter struct {
	commonTags map[string]string
}

// IncCounter logs the counter name, tags and increment.
func (simpleStatsReporter) IncCounter(name string, tags map[string]string, value int64) {
	const format = "Stats: IncCounter(%v, %v) +%v"
	log.Printf(format, name, tags, value)
}

// UpdateGauge logs the gauge name, tags and new value.
func (simpleStatsReporter) UpdateGauge(name string, tags map[string]string, value int64) {
	const format = "Stats: UpdateGauge(%v, %v) = %v"
	log.Printf(format, name, tags, value)
}

// RecordTimer logs the timer name, tags and recorded duration.
func (simpleStatsReporter) RecordTimer(name string, tags map[string]string, d time.Duration) {
	const format = "Stats: RecordTimer(%v, %v) = %v"
	log.Printf(format, name, tags, d)
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "fmt" "os" "testing" "time" . 
"github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) func tagsForOutboundCall(serverCh *Channel, clientCh *Channel, method string) map[string]string { host, _ := os.Hostname() return map[string]string{ "app": clientCh.PeerInfo().ProcessName, "host": host, "service": clientCh.PeerInfo().ServiceName, "target-service": serverCh.PeerInfo().ServiceName, "target-endpoint": method, } } func tagsForInboundCall(serverCh *Channel, clientCh *Channel, method string) map[string]string { host, _ := os.Hostname() return map[string]string{ "app": serverCh.PeerInfo().ProcessName, "host": host, "service": serverCh.PeerInfo().ServiceName, "calling-service": clientCh.PeerInfo().ServiceName, "endpoint": method, } } // statsHandler increments the server and client timers when handling requests. type statsHandler struct { *testHandler clientClock *testutils.StubClock serverClock *testutils.StubClock } func (h *statsHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { h.clientClock.Elapse(100 * time.Millisecond) h.serverClock.Elapse(70 * time.Millisecond) return h.testHandler.Handle(ctx, args) } func TestStatsCalls(t *testing.T) { defer testutils.SetTimeout(t, 2*time.Second)() tests := []struct { method string wantErr bool }{ { method: "echo", }, { method: "app-error", wantErr: true, }, } for _, tt := range tests { initialTime := time.Date(2015, 2, 1, 10, 10, 0, 0, time.UTC) clientClock := testutils.NewStubClock(initialTime) serverClock := testutils.NewStubClock(initialTime) handler := &statsHandler{ testHandler: newTestHandler(t), clientClock: clientClock, serverClock: serverClock, } clientStats := newRecordingStatsReporter() serverStats := newRecordingStatsReporter() serverOpts := testutils.NewOpts(). SetStatsReporter(serverStats). 
SetTimeNow(serverClock.Now) WithVerifiedServer(t, serverOpts, func(serverCh *Channel, hostPort string) { handler := raw.Wrap(handler) serverCh.Register(handler, "echo") serverCh.Register(handler, "app-error") ch := testutils.NewClient(t, testutils.NewOpts(). SetStatsReporter(clientStats). SetTimeNow(clientClock.Now)) defer ch.Close() ctx, cancel := NewContext(time.Second * 5) defer cancel() _, _, resp, err := raw.Call(ctx, ch, hostPort, testutils.DefaultServerName, tt.method, nil, nil) require.NoError(t, err, "Call(%v) should fail", tt.method) assert.Equal(t, tt.wantErr, resp.ApplicationError(), "Call(%v) check application error") outboundTags := tagsForOutboundCall(serverCh, ch, tt.method) inboundTags := tagsForInboundCall(serverCh, ch, tt.method) clientStats.Expected.IncCounter("outbound.calls.send", outboundTags, 1) clientStats.Expected.RecordTimer("outbound.calls.per-attempt.latency", outboundTags, 100*time.Millisecond) clientStats.Expected.RecordTimer("outbound.calls.latency", outboundTags, 100*time.Millisecond) serverStats.Expected.IncCounter("inbound.calls.recvd", inboundTags, 1) serverStats.Expected.RecordTimer("inbound.calls.latency", inboundTags, 70*time.Millisecond) if tt.wantErr { clientStats.Expected.IncCounter("outbound.calls.per-attempt.app-errors", outboundTags, 1) clientStats.Expected.IncCounter("outbound.calls.app-errors", outboundTags, 1) serverStats.Expected.IncCounter("inbound.calls.app-errors", inboundTags, 1) } else { clientStats.Expected.IncCounter("outbound.calls.success", outboundTags, 1) serverStats.Expected.IncCounter("inbound.calls.success", inboundTags, 1) } }) clientStats.Validate(t) serverStats.Validate(t) } } func TestStatsWithRetries(t *testing.T) { defer testutils.SetTimeout(t, 2*time.Second)() a := testutils.DurationArray initialTime := time.Date(2015, 2, 1, 10, 10, 0, 0, time.UTC) clientClock := testutils.NewStubClock(initialTime) clientStats := newRecordingStatsReporter() ch := testutils.NewClient(t, testutils.NewOpts(). 
SetStatsReporter(clientStats). SetTimeNow(clientClock.Now)) defer ch.Close() ctx, cancel := NewContext(time.Second) defer cancel() // TODO why do we need this?? opts := testutils.NewOpts().NoRelay() WithVerifiedServer(t, opts, func(serverCh *Channel, hostPort string) { const ( perAttemptServer = 10 * time.Millisecond perAttemptClient = time.Millisecond perAttemptTotal = perAttemptServer + perAttemptClient ) respErr := make(chan error, 1) testutils.RegisterFunc(serverCh, "req", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { clientClock.Elapse(perAttemptServer) return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, <-respErr }) ch.Peers().Add(serverCh.PeerInfo().HostPort) tests := []struct { expectErr error numFailures int numAttempts int overallLatency time.Duration perAttemptLatencies []time.Duration }{ { numFailures: 0, numAttempts: 1, perAttemptLatencies: a(perAttemptServer), overallLatency: perAttemptTotal, }, { numFailures: 1, numAttempts: 2, perAttemptLatencies: a(perAttemptServer, perAttemptServer), overallLatency: 2 * perAttemptTotal, }, { numFailures: 4, numAttempts: 5, perAttemptLatencies: a(perAttemptServer, perAttemptServer, perAttemptServer, perAttemptServer, perAttemptServer), overallLatency: 5 * perAttemptTotal, }, { numFailures: 5, numAttempts: 5, expectErr: ErrServerBusy, perAttemptLatencies: a(perAttemptServer, perAttemptServer, perAttemptServer, perAttemptServer, perAttemptServer), overallLatency: 5 * perAttemptTotal, }, } for _, tt := range tests { clientStats.Reset() err := ch.RunWithRetry(ctx, func(ctx context.Context, rs *RequestState) error { clientClock.Elapse(perAttemptClient) if rs.Attempt > tt.numFailures { respErr <- nil } else { respErr <- ErrServerBusy } sc := ch.GetSubChannel(serverCh.ServiceName()) _, err := raw.CallV2(ctx, sc, raw.CArgs{ Method: "req", CallOptions: &CallOptions{RequestState: rs}, }) return err }) assert.Equal(t, tt.expectErr, err, "RunWithRetry unexpected error") outboundTags := 
tagsForOutboundCall(serverCh, ch, "req") if tt.expectErr == nil { clientStats.Expected.IncCounter("outbound.calls.success", outboundTags, 1) } clientStats.Expected.IncCounter("outbound.calls.send", outboundTags, int64(tt.numAttempts)) for i, latency := range tt.perAttemptLatencies { clientStats.Expected.RecordTimer("outbound.calls.per-attempt.latency", outboundTags, latency) if i > 0 { tags := tagsForOutboundCall(serverCh, ch, "req") tags["retry-count"] = fmt.Sprint(i) clientStats.Expected.IncCounter("outbound.calls.retries", tags, 1) } } clientStats.Expected.RecordTimer("outbound.calls.latency", outboundTags, tt.overallLatency) clientStats.Validate(t) } }) } ================================================ FILE: stats_utils_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel_test

// This file contains test setup logic, and is named with a _test.go suffix to
// ensure it's only compiled with tests.

import (
	"fmt"
	"reflect"
	"sort"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// statsValue holds what was recorded for a single (metric name, tags) pair.
type statsValue struct {
	// count is the counter value if this metric is a counter.
	count int64

	// timers is the list of timer values if this metric is a timer.
	timers []time.Duration
}

// recordingStatsReporter stores every reported counter and timer so tests can
// compare them with the values accumulated on Expected.
type recordingStatsReporter struct {
	sync.Mutex

	// Values is a map from the metricName -> map[tagMapAsString]*statsValue
	Values map[string]map[string]*statsValue

	// Expected stores expected counter values.
	Expected *recordingStatsReporter
}

// newRecordingStatsReporter returns a reporter with empty Values and an
// Expected reporter that also has empty Values (and a nil Expected).
func newRecordingStatsReporter() *recordingStatsReporter {
	expected := &recordingStatsReporter{
		Values: make(map[string]map[string]*statsValue),
	}
	return &recordingStatsReporter{
		Values:   make(map[string]map[string]*statsValue),
		Expected: expected,
	}
}

// keysMap returns the keys of the given map as a sorted list of strings.
// If the map is not of the type map[string]* then the function will panic.
func keysMap(m interface{}) []string {
	var names []string
	for _, key := range reflect.ValueOf(m).MapKeys() {
		names = append(names, key.Interface().(string))
	}
	sort.Strings(names)
	return names
}

// tagsToString converts a map of tags to a string that can be used as a map key.
func tagsToString(tags map[string]string) string { var vals []string for _, k := range keysMap(tags) { vals = append(vals, fmt.Sprintf("%v = %v", k, tags[k])) } return strings.Join(vals, ", ") } func (r *recordingStatsReporter) getStat(name string, tags map[string]string) *statsValue { r.Lock() defer r.Unlock() tagMap, ok := r.Values[name] if !ok { tagMap = make(map[string]*statsValue) r.Values[name] = tagMap } tagStr := tagsToString(tags) statVal, ok := tagMap[tagStr] if !ok { statVal = &statsValue{} tagMap[tagStr] = statVal } return statVal } func (r *recordingStatsReporter) IncCounter(name string, tags map[string]string, value int64) { statVal := r.getStat(name, tags) statVal.count += value } func (r *recordingStatsReporter) RecordTimer(name string, tags map[string]string, d time.Duration) { statVal := r.getStat(name, tags) statVal.timers = append(statVal.timers, d) } func (r *recordingStatsReporter) Reset() { newReporter := newRecordingStatsReporter() r.Values = newReporter.Values r.Expected = newReporter.Expected } func (r *recordingStatsReporter) Validate(t *testing.T) { r.Lock() defer r.Unlock() assert.Equal(t, keysMap(r.Expected.Values), keysMap(r.Values), "Metric keys are different") r.validateExpectedLocked(t) } // ValidateExpected only validates metrics added to expected rather than all recorded metrics. 
func (r *recordingStatsReporter) ValidateExpected(t testing.TB) { r.Lock() defer r.Unlock() r.validateExpectedLocked(t) } func (r *recordingStatsReporter) EnsureNotPresent(t testing.TB, counter string) { r.Lock() defer r.Unlock() assert.NotContains(t, r.Values, counter, "metric should not be present") } func (r *recordingStatsReporter) validateExpectedLocked(t testing.TB) { for counterKey, expectedCounter := range r.Expected.Values { counter, ok := r.Values[counterKey] if !assert.True(t, ok, "expected %v not found", counterKey) { continue } assert.Equal(t, keysMap(expectedCounter), keysMap(counter), "Metric %v has different reported tags", counterKey) for tags, stat := range counter { expectedStat, ok := expectedCounter[tags] if !ok { continue } assert.Equal(t, expectedStat, stat, "Metric %v with tags %v has mismatched value", counterKey, tags) } } } func (r *recordingStatsReporter) UpdateGauge(name string, tags map[string]string, value int64) {} ================================================ FILE: stream_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "errors" "fmt" "io" "io/ioutil" "strings" "testing" "time" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) const ( streamRequestError = byte(255) streamRequestClose = byte(254) ) func makeRepeatedBytes(n byte) []byte { data := make([]byte, int(n)) for i := byte(0); i < n; i++ { data[i] = n } return data } func writeFlushBytes(w ArgWriter, bs []byte) error { if _, err := w.Write(bs); err != nil { return err } return w.Flush() } type streamHelper struct { t testing.TB } // startCall starts a call to echoStream and returns the arg3 reader and writer. func (h streamHelper) startCall(ctx context.Context, ch *Channel, hostPort, serviceName string) (ArgWriter, ArgReader) { call, err := ch.BeginCall(ctx, hostPort, serviceName, "echoStream", nil) require.NoError(h.t, err, "BeginCall to echoStream failed") // Write empty headers require.NoError(h.t, NewArgWriter(call.Arg2Writer()).Write(nil), "Write empty headers failed") // Flush arg3 to force the call to start without any arg3. 
writer, err := call.Arg3Writer() require.NoError(h.t, err, "Arg3Writer failed") require.NoError(h.t, writer.Flush(), "Arg3Writer flush failed") // Read empty Headers response := call.Response() var arg2 []byte require.NoError(h.t, NewArgReader(response.Arg2Reader()).Read(&arg2), "Read headers failed") require.False(h.t, response.ApplicationError(), "echoStream failed due to application error") reader, err := response.Arg3Reader() require.NoError(h.t, err, "Arg3Reader failed") return writer, reader } // streamPartialHandler returns a streaming handler that has the following contract: // read a byte, write N bytes where N = the byte that was read. // The results are be written as soon as the byte is read. func streamPartialHandler(t testing.TB, reportErrors bool) HandlerFunc { return func(ctx context.Context, call *InboundCall) { response := call.Response() onError := func(err error) { if reportErrors { t.Errorf("Handler error: %v", err) } response.SendSystemError(fmt.Errorf("failed to read arg2")) } var arg2 []byte if err := NewArgReader(call.Arg2Reader()).Read(&arg2); err != nil { onError(fmt.Errorf("failed to read arg2")) return } if err := NewArgWriter(response.Arg2Writer()).Write(nil); err != nil { onError(fmt.Errorf("")) return } argReader, err := call.Arg3Reader() if err != nil { onError(fmt.Errorf("failed to read arg3")) return } argWriter, err := response.Arg3Writer() if err != nil { onError(fmt.Errorf("arg3 writer failed")) return } // Flush arg3 which will force a frame with just arg2 to be sent. // The test reads arg2 before arg3 has been sent. 
if err := argWriter.Flush(); err != nil { onError(fmt.Errorf("arg3 flush failed")) return } arg3 := make([]byte, 1) for { n, err := argReader.Read(arg3) if err == io.EOF { break } if n == 0 && err == nil { err = fmt.Errorf("read 0 bytes") } if err != nil { onError(fmt.Errorf("arg3 Read failed: %v", err)) return } // Magic number to cause a failure if arg3[0] == streamRequestError { // Make sure that the reader is closed. if err := argReader.Close(); err != nil { onError(fmt.Errorf("request error failed to close argReader: %v", err)) return } response.SendSystemError(errors.New("intentional failure")) return } if arg3[0] == streamRequestClose { if err := argWriter.Close(); err != nil { onError(err) } return } // Write the number of bytes as specified by arg3[0] if _, err := argWriter.Write(makeRepeatedBytes(arg3[0])); err != nil { onError(fmt.Errorf("argWriter Write failed: %v", err)) return } if err := argWriter.Flush(); err != nil { onError(fmt.Errorf("argWriter flush failed: %v", err)) return } } if err := argReader.Close(); err != nil { onError(fmt.Errorf("argReader Close failed: %v", err)) return } if err := argWriter.Close(); err != nil { onError(fmt.Errorf("arg3writer Close failed: %v", err)) return } } } func testStreamArg(t *testing.T, f func(argWriter ArgWriter, argReader ArgReader)) { defer testutils.SetTimeout(t, 2*time.Second)() ctx, cancel := NewContext(time.Second) defer cancel() helper := streamHelper{t} WithVerifiedServer(t, nil, func(ch *Channel, hostPort string) { ch.Register(streamPartialHandler(t, true /* report errors */), "echoStream") argWriter, argReader := helper.startCall(ctx, ch, hostPort, ch.ServiceName()) verifyBytes := func(n byte) { require.NoError(t, writeFlushBytes(argWriter, []byte{n}), "arg3 write failed") arg3 := make([]byte, int(n)) _, err := io.ReadFull(argReader, arg3) require.NoError(t, err, "arg3 read failed") assert.Equal(t, makeRepeatedBytes(n), arg3, "arg3 result mismatch") } verifyBytes(0) verifyBytes(5) verifyBytes(100) 
verifyBytes(1) f(argWriter, argReader) }) } func TestStreamPartialArg(t *testing.T) { testStreamArg(t, func(argWriter ArgWriter, argReader ArgReader) { require.NoError(t, argWriter.Close(), "arg3 close failed") // Once closed, we expect the reader to return EOF n, err := io.Copy(ioutil.Discard, argReader) assert.Equal(t, int64(0), n, "arg2 reader expected to EOF after arg3 writer is closed") assert.NoError(t, err, "Copy should not fail") assert.NoError(t, argReader.Close(), "close arg reader failed") }) } func TestStreamSendError(t *testing.T) { testStreamArg(t, func(argWriter ArgWriter, argReader ArgReader) { // Send the magic number to request an error. _, err := argWriter.Write([]byte{streamRequestError}) require.NoError(t, err, "arg3 write failed") require.NoError(t, argWriter.Close(), "arg3 close failed") // Now we expect an error on our next read. _, err = ioutil.ReadAll(argReader) assert.Error(t, err, "ReadAll should fail") assert.True(t, strings.Contains(err.Error(), "intentional failure"), "err %v unexpected", err) }) } func TestStreamCancelled(t *testing.T) { // Since the cancel message is unimplemented, the relay does not know that the // call was cancelled, andwill block closing till the timeout. opts := testutils.NewOpts().NoRelay() testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { ts.Register(streamPartialHandler(t, false /* report errors */), "echoStream") ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() helper := streamHelper{t} client := ts.NewClient(nil) cancelContext := make(chan struct{}) arg3Writer, arg3Reader := helper.startCall(ctx, client, ts.HostPort(), ts.ServiceName()) go func() { for i := 0; i < 10; i++ { assert.NoError(t, writeFlushBytes(arg3Writer, []byte{1}), "Write failed") } // Our reads and writes should fail now. <-cancelContext cancel() _, err := arg3Writer.Write([]byte{1}) // The write will succeed since it's buffered. 
assert.NoError(t, err, "Write after fail should be buffered") assert.Error(t, arg3Writer.Flush(), "writer.Flush should fail after cancel") assert.Error(t, arg3Writer.Close(), "writer.Close should fail after cancel") }() for i := 0; i < 10; i++ { arg3 := make([]byte, 1) n, err := arg3Reader.Read(arg3) assert.Equal(t, 1, n, "Read did not correct number of bytes") assert.NoError(t, err, "Read failed") } close(cancelContext) n, err := io.Copy(ioutil.Discard, arg3Reader) assert.EqualValues(t, 0, n, "Read should not read any bytes after cancel") assert.Error(t, err, "Read should fail after cancel") assert.Error(t, arg3Reader.Close(), "reader.Close should fail after cancel") // Close the client to clear out the pending exchange. Otherwise the test // waits for the timeout, causing a slowdown. client.Close() }) } func TestResponseClosedBeforeRequest(t *testing.T) { testutils.WithTestServer(t, nil, func(t testing.TB, ts *testutils.TestServer) { ts.Register(streamPartialHandler(t, false /* report errors */), "echoStream") ctx, cancel := NewContext(testutils.Timeout(time.Second)) defer cancel() helper := streamHelper{t} ch := ts.NewClient(nil) responseClosed := make(chan struct{}) writerDone := make(chan struct{}) arg3Writer, arg3Reader := helper.startCall(ctx, ch, ts.HostPort(), ts.Server().ServiceName()) go func() { defer close(writerDone) for i := 0; i < 10; i++ { assert.NoError(t, writeFlushBytes(arg3Writer, []byte{1}), "Write failed") } // Ignore the error of writeFlushBytes here since once we flush, the // remote side could receive and close the response before we've created // a new fragment (see fragmentingWriter.Flush). This could result // in the Flush returning a "mex is already shutdown" error. writeFlushBytes(arg3Writer, []byte{streamRequestClose}) // Wait until our reader gets the EOF. 
<-responseClosed // Now our writes should fail, since the stream is shutdown err := writeFlushBytes(arg3Writer, []byte{1}) if assert.Error(t, err, "Req write should fail since response stream ended") { assert.Contains(t, err.Error(), "mex has been shutdown") } }() for i := 0; i < 10; i++ { arg3 := make([]byte, 1) n, err := arg3Reader.Read(arg3) assert.Equal(t, 1, n, "Read did not correct number of bytes") assert.NoError(t, err, "Read failed") } eofBuf := make([]byte, 1) _, err := arg3Reader.Read(eofBuf) assert.Equal(t, io.EOF, err, "Response should EOF after request close") assert.NoError(t, arg3Reader.Close(), "Close should succeed") close(responseClosed) <-writerDone }) } ================================================ FILE: stress_flag_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel_test

import (
	"flag"
	"testing"
)

// This file defines the stressTest flag and the helper tests use to skip
// themselves when it is not set. Since it has a _test.go suffix, it is only
// compiled with tests in this package.

// flagStressTest is the -stressTest command line flag; it defaults to false.
var flagStressTest = flag.Bool("stressTest", false, "Run stress tests (very slow)")

// CheckStress will skip the test if stress testing is not enabled.
func CheckStress(t *testing.T) {
	if *flagStressTest {
		return
	}
	t.Skip("Skipping long-running test as stressTest is not set")
}

================================================
FILE: subchannel.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

import (
	"fmt"
	"sync"

	"github.com/opentracing/opentracing-go"
	"golang.org/x/net/context"
)

// SubChannelOption are used to set options for subchannels.
type SubChannelOption func(*SubChannel)

// Isolated is a SubChannelOption that creates an isolated subchannel.
func Isolated(s *SubChannel) { s.Lock() s.peers = s.topChannel.peers.newSibling() s.peers.SetStrategy(newLeastPendingCalculator()) s.Unlock() } // SubChannel allows calling a specific service on a channel. // TODO(prashant): Allow creating a subchannel with default call options. // TODO(prashant): Allow registering handlers on a subchannel. type SubChannel struct { sync.RWMutex serviceName string topChannel *Channel defaultCallOptions *CallOptions peers *PeerList handler Handler logger Logger statsReporter StatsReporter } // Map of subchannel and the corresponding service type subChannelMap struct { sync.RWMutex subchannels map[string]*SubChannel } func newSubChannel(serviceName string, ch *Channel) *SubChannel { logger := ch.Logger().WithFields(LogField{"subchannel", serviceName}) return &SubChannel{ serviceName: serviceName, peers: ch.peers, topChannel: ch, handler: &handlerMap{}, // use handlerMap by default logger: logger, statsReporter: ch.StatsReporter(), } } // ServiceName returns the service name that this subchannel is for. func (c *SubChannel) ServiceName() string { return c.serviceName } // BeginCall starts a new call to a remote peer, returning an OutboundCall that can // be used to write the arguments of the call. func (c *SubChannel) BeginCall(ctx context.Context, methodName string, callOptions *CallOptions) (*OutboundCall, error) { if callOptions == nil { callOptions = defaultCallOptions } peer, err := c.peers.Get(callOptions.RequestState.PrevSelectedPeers()) if err != nil { return nil, err } return peer.BeginCall(ctx, c.ServiceName(), methodName, callOptions) } // Peers returns the PeerList for this subchannel. func (c *SubChannel) Peers() *PeerList { return c.peers } // Isolated returns whether this subchannel is an isolated subchannel. func (c *SubChannel) Isolated() bool { c.RLock() defer c.RUnlock() return c.topChannel.Peers() != c.peers } // Register registers a handler on the subchannel for the given method. 
// // This function panics if the Handler for the SubChannel was overwritten with // SetHandler. func (c *SubChannel) Register(h Handler, methodName string) { r, ok := c.handler.(registrar) if !ok { panic(fmt.Sprintf( "handler for SubChannel(%v) configured with alternate root handler without Register method", c.ServiceName(), )) } r.Register(h, methodName) } // GetHandlers returns all handlers registered on this subchannel by method name. // // This function panics if the Handler for the SubChannel was overwritten with // SetHandler. func (c *SubChannel) GetHandlers() map[string]Handler { handlers, ok := c.handler.(*handlerMap) if !ok { panic(fmt.Sprintf( "handler for SubChannel(%v) was changed to disallow method registration", c.ServiceName(), )) } handlers.RLock() handlersMap := make(map[string]Handler, len(handlers.handlers)) for k, v := range handlers.handlers { handlersMap[k] = v } handlers.RUnlock() return handlersMap } // SetHandler changes the SubChannel's underlying handler. This may be used to // set up a catch-all Handler for all requests received by this SubChannel. // // Methods registered on this SubChannel using Register() before calling // SetHandler() will be forgotten. Further calls to Register() on this // SubChannel after SetHandler() is called will cause panics. func (c *SubChannel) SetHandler(h Handler) { c.handler = h } // Logger returns the logger for this subchannel. func (c *SubChannel) Logger() Logger { return c.logger } // StatsReporter returns the stats reporter for this subchannel. func (c *SubChannel) StatsReporter() StatsReporter { return c.topChannel.StatsReporter() } // StatsTags returns the stats tags for this subchannel. func (c *SubChannel) StatsTags() map[string]string { tags := c.topChannel.StatsTags() tags["subchannel"] = c.serviceName return tags } // Tracer returns OpenTracing Tracer from the top channel. 
func (c *SubChannel) Tracer() opentracing.Tracer { return c.topChannel.Tracer() } // Register a new subchannel for the given serviceName func (subChMap *subChannelMap) registerNewSubChannel(serviceName string, ch *Channel) (_ *SubChannel, added bool) { subChMap.Lock() defer subChMap.Unlock() if subChMap.subchannels == nil { subChMap.subchannels = make(map[string]*SubChannel) } if sc, ok := subChMap.subchannels[serviceName]; ok { return sc, false } sc := newSubChannel(serviceName, ch) subChMap.subchannels[serviceName] = sc return sc, true } // Get subchannel if, we have one func (subChMap *subChannelMap) get(serviceName string) (*SubChannel, bool) { subChMap.RLock() sc, ok := subChMap.subchannels[serviceName] subChMap.RUnlock() return sc, ok } // GetOrAdd a subchannel for the given serviceName on the map func (subChMap *subChannelMap) getOrAdd(serviceName string, ch *Channel) (_ *SubChannel, added bool) { if sc, ok := subChMap.get(serviceName); ok { return sc, false } return subChMap.registerNewSubChannel(serviceName, ch) } func (subChMap *subChannelMap) updatePeer(p *Peer) { subChMap.RLock() for _, subCh := range subChMap.subchannels { if subCh.Isolated() { subCh.RLock() subCh.Peers().onPeerChange(p) subCh.RUnlock() } } subChMap.RUnlock() } ================================================ FILE: subchannel_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "testing" "time" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" ) type chanSet struct { main Registrar sub Registrar isolated Registrar } func withNewSet(t *testing.T, f func(*testing.T, chanSet)) { ch := testutils.NewClient(t, nil) defer ch.Close() f(t, chanSet{ main: ch, sub: ch.GetSubChannel("hyperbahn"), isolated: ch.GetSubChannel("ringpop", Isolated), }) } // Assert that two Registrars have references to the same Peer. 
func assertHaveSameRef(t *testing.T, r1, r2 Registrar) { p1, err := r1.Peers().Get(nil) assert.NoError(t, err, "First registrar has no peers.") p2, err := r2.Peers().Get(nil) assert.NoError(t, err, "Second registrar has no peers.") assert.True(t, p1 == p2, "Registrars have references to different peers.") } func assertNoPeer(t *testing.T, r Registrar) { _, err := r.Peers().Get(nil) assert.Equal(t, err, ErrNoPeers) } func TestMainAddVisibility(t *testing.T) { withNewSet(t, func(t *testing.T, set chanSet) { // Adding a peer to the main channel should be reflected in the // subchannel, but not the isolated subchannel. set.main.Peers().Add("127.0.0.1:3000") assertHaveSameRef(t, set.main, set.sub) assertNoPeer(t, set.isolated) }) } func TestSubchannelAddVisibility(t *testing.T) { withNewSet(t, func(t *testing.T, set chanSet) { // Adding a peer to a non-isolated subchannel should be reflected in // the main channel but not in isolated siblings. set.sub.Peers().Add("127.0.0.1:3000") assertHaveSameRef(t, set.main, set.sub) assertNoPeer(t, set.isolated) }) } func TestIsolatedAddVisibility(t *testing.T) { withNewSet(t, func(t *testing.T, set chanSet) { // Adding a peer to an isolated subchannel shouldn't change the main // channel or sibling channels. set.isolated.Peers().Add("127.0.0.1:3000") _, err := set.isolated.Peers().Get(nil) assert.NoError(t, err) assertNoPeer(t, set.main) assertNoPeer(t, set.sub) }) } func TestAddReusesPeers(t *testing.T) { withNewSet(t, func(t *testing.T, set chanSet) { // Adding to both a channel and an isolated subchannel shouldn't create // two separate peers. set.main.Peers().Add("127.0.0.1:3000") set.isolated.Peers().Add("127.0.0.1:3000") assertHaveSameRef(t, set.main, set.sub) assertHaveSameRef(t, set.main, set.isolated) }) } func TestSetHandler(t *testing.T) { // Generate a Handler that expects only the given methods to be called. 
genHandler := func(methods ...string) Handler { allowedMethods := make(map[string]struct{}, len(methods)) for _, m := range methods { allowedMethods[m] = struct{}{} } return HandlerFunc(func(ctx context.Context, call *InboundCall) { method := call.MethodString() assert.Contains(t, allowedMethods, method, "unexpected call to %q", method) err := raw.WriteResponse(call.Response(), &raw.Res{Arg3: []byte(method)}) require.NoError(t, err) }) } ch := testutils.NewServer(t, testutils.NewOpts(). AddLogFilter("Couldn't find handler", 1, "serviceName", "svc2", "method", "bar")) defer ch.Close() // Catch-all handler for the main channel that accepts foo, bar, and baz, // and a single registered handler for a different subchannel. ch.GetSubChannel("svc1").SetHandler(genHandler("foo", "bar", "baz")) ch.GetSubChannel("svc2").Register(genHandler("foo"), "foo") client := testutils.NewClient(t, nil) client.Peers().Add(ch.PeerInfo().HostPort) defer client.Close() tests := []struct { Service string Method string ShouldFail bool }{ {"svc1", "foo", false}, {"svc1", "bar", false}, {"svc1", "baz", false}, {"svc2", "foo", false}, {"svc2", "bar", true}, } for _, tt := range tests { c := client.GetSubChannel(tt.Service) ctx, _ := NewContext(time.Second) _, data, _, err := raw.CallSC(ctx, c, tt.Method, nil, []byte("irrelevant")) if tt.ShouldFail { require.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tt.Method, string(data)) } } st := ch.IntrospectState(nil) assert.Equal(t, "overriden", st.SubChannels["svc1"].Handler.Type.String()) assert.Nil(t, st.SubChannels["svc1"].Handler.Methods) assert.Equal(t, "methods", st.SubChannels["svc2"].Handler.Type.String()) assert.Equal(t, []string{"foo"}, st.SubChannels["svc2"].Handler.Methods) } func TestGetHandlers(t *testing.T) { ch := testutils.NewServer(t, nil) defer ch.Close() var handler1 HandlerFunc = func(_ context.Context, _ *InboundCall) { panic("unexpected call") } var handler2 HandlerFunc = func(_ context.Context, _ *InboundCall) 
{ panic("unexpected call") } ch.Register(handler1, "method1") ch.Register(handler2, "method2") ch.GetSubChannel("foo").Register(handler2, "method1") tests := []struct { serviceName string wantMethods []string }{ { serviceName: ch.ServiceName(), wantMethods: []string{"_gometa_introspect", "_gometa_runtime", "method1", "method2"}, }, { serviceName: "foo", wantMethods: []string{"method1"}, }, } for _, tt := range tests { handlers := ch.GetSubChannel(tt.serviceName).GetHandlers() if !assert.Equal(t, len(tt.wantMethods), len(handlers), "Unexpected number of methods found, expected %v, got %v", tt.wantMethods, handlers) { continue } for _, method := range tt.wantMethods { _, ok := handlers[method] assert.True(t, ok, "Expected to find method %v in handlers: %v", method, handlers) } } } func TestCannotRegisterOrGetAfterSetHandler(t *testing.T) { ch := testutils.NewServer(t, nil) defer ch.Close() var someHandler HandlerFunc = func(ctx context.Context, call *InboundCall) { panic("unexpected call") } var anotherHandler HandlerFunc = func(ctx context.Context, call *InboundCall) { panic("unexpected call") } ch.GetSubChannel("foo").SetHandler(someHandler) // Registering against the original service should not panic but // registering against the "foo" service should panic since the handler // was overridden, and doesn't support Register. 
assert.NotPanics(t, func() { ch.Register(anotherHandler, "bar") }) assert.NotPanics(t, func() { ch.GetSubChannel("svc").GetHandlers() }) assert.Panics(t, func() { ch.GetSubChannel("foo").Register(anotherHandler, "bar") }) assert.Panics(t, func() { ch.GetSubChannel("foo").GetHandlers() }) } func TestGetSubchannelOptionsOnNew(t *testing.T) { ch := testutils.NewServer(t, nil) defer ch.Close() peers := ch.GetSubChannel("s", Isolated).Peers() want := peers.Add("1.1.1.1:1") peers2 := ch.GetSubChannel("s", Isolated).Peers() assert.Equal(t, peers, peers2, "Get isolated subchannel should not clear existing peers") peer, err := peers2.Get(nil) require.NoError(t, err, "Should get peer") assert.Equal(t, want, peer, "Unexpected peer") } func TestHandlerWithoutSubChannel(t *testing.T) { opts := testutils.NewOpts().NoRelay() opts.Handler = raw.Wrap(newTestHandler(t)) testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) { client := ts.NewClient(nil) testutils.AssertEcho(t, client, ts.HostPort(), ts.ServiceName()) testutils.AssertEcho(t, client, ts.HostPort(), "larry") testutils.AssertEcho(t, client, ts.HostPort(), "curly") testutils.AssertEcho(t, client, ts.HostPort(), "moe") assert.Panics(t, func() { ts.Server().Register(raw.Wrap(newTestHandler(t)), "nyuck") }) }) } type handlerWithRegister struct { registered map[string]struct{} } func (handlerWithRegister) Handle(ctx context.Context, call *InboundCall) { panic("Handle not expected to be called") } func (hr *handlerWithRegister) Register(h Handler, methodName string) { if hr.registered == nil { hr.registered = make(map[string]struct{}) } hr.registered[methodName] = struct{}{} } func TestHandlerCustomRegister(t *testing.T) { hrTop := &handlerWithRegister{} hrSC := &handlerWithRegister{} opts := testutils.NewOpts() opts.ChannelOptions.Handler = hrTop ch := testutils.NewServer(t, opts) defer ch.Close() var unused HandlerFunc = func(_ context.Context, _ *InboundCall) { panic("unexpected call") } 
ch.Register(unused, "Top-Method") sc := ch.GetSubChannel("sc") sc.SetHandler(hrSC) sc.Register(unused, "SC-Method") assert.Equal(t, map[string]struct{}{ "Top-Method": {}, }, hrTop.registered, "Register on top channel mismatch") assert.Equal(t, map[string]struct{}{ "SC-Method": {}, }, hrSC.registered, "Register on subchannel mismatch") } ================================================ FILE: systemerrcode_string.go ================================================ // generated by stringer -type=SystemErrCode; DO NOT EDIT package tchannel import "fmt" const ( _SystemErrCode_name_0 = "ErrCodeInvalidErrCodeTimeoutErrCodeCancelledErrCodeBusyErrCodeDeclinedErrCodeUnexpectedErrCodeBadRequestErrCodeNetwork" _SystemErrCode_name_1 = "ErrCodeProtocol" ) var ( _SystemErrCode_index_0 = [...]uint8{0, 14, 28, 44, 55, 70, 87, 104, 118} _SystemErrCode_index_1 = [...]uint8{0, 15} ) func (i SystemErrCode) String() string { switch { case 0 <= i && i <= 7: return _SystemErrCode_name_0[_SystemErrCode_index_0[i]:_SystemErrCode_index_0[i+1]] case i == 255: return _SystemErrCode_name_1 default: return fmt.Sprintf("SystemErrCode(%d)", i) } } ================================================ FILE: tchannel_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tchannel_test import ( "flag" "fmt" "os" "testing" . "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils/goroutines" ) func checkAllChannels() error { ch, err := NewChannel("test-end", nil) if err != nil { return err } var foundChannels bool allChannels := ch.IntrospectOthers(&IntrospectionOptions{}) for _, cs := range allChannels { if len(cs) != 0 { foundChannels = true } } if !foundChannels { return nil } return fmt.Errorf("unclosed channels:\n%+v", allChannels) } func TestMain(m *testing.M) { flag.Parse() exitCode := m.Run() if exitCode == 0 { // Only do extra checks if the tests were successful. if err := goroutines.IdentifyLeaks(nil); err != nil { fmt.Fprintf(os.Stderr, "Found goroutine leaks on successful test run: %v", err) exitCode = 1 } if err := checkAllChannels(); err != nil { fmt.Fprintf(os.Stderr, "Found unclosed channels on successful test run: %v", err) exitCode = 1 } } os.Exit(exitCode) } ================================================ FILE: testutils/call.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package testutils

import (
	"testing"
	"time"

	"github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/relay"
	"github.com/uber/tchannel-go/testutils/thriftarg2test"
	"github.com/uber/tchannel-go/thrift/arg2"
)

// FakeIncomingCall implements IncomingCall interface.
// Note: the F suffix on the field names avoids a clash with the accessor
// methods of the same name (e.g. the CallerNameF field and the CallerName
// method).
type FakeIncomingCall struct {
	// CallerNameF is the calling service's name.
	CallerNameF string

	// ShardKeyF is the intended destination for this call.
	ShardKeyF string

	// RemotePeerF is the calling service's peer info.
	RemotePeerF tchannel.PeerInfo

	// LocalPeerF is the local service's peer info.
	LocalPeerF tchannel.LocalPeerInfo

	// RoutingKeyF is the routing key.
	RoutingKeyF string

	// RoutingDelegateF is the routing delegate.
	RoutingDelegateF string
}

// CallerName returns the caller name as specified in the fake call.
func (f *FakeIncomingCall) CallerName() string {
	return f.CallerNameF
}

// ShardKey returns the shard key as specified in the fake call.
func (f *FakeIncomingCall) ShardKey() string {
	return f.ShardKeyF
}

// RoutingKey returns the routing key as specified in the fake call.
func (f *FakeIncomingCall) RoutingKey() string {
	return f.RoutingKeyF
}

// RoutingDelegate returns the routing delegate as specified in the fake call.
func (f *FakeIncomingCall) RoutingDelegate() string {
	return f.RoutingDelegateF
}

// LocalPeer returns the local peer information for this call.
func (f *FakeIncomingCall) LocalPeer() tchannel.LocalPeerInfo {
	return f.LocalPeerF
}

// RemotePeer returns the remote peer information for this call.
func (f *FakeIncomingCall) RemotePeer() tchannel.PeerInfo {
	return f.RemotePeerF
}

// CallOptions returns the incoming call options suitable for proxying a request.
// Only ShardKey, RoutingKey and RoutingDelegate are populated, from the
// corresponding accessors; every other CallOptions field is left at its zero
// value.
func (f *FakeIncomingCall) CallOptions() *tchannel.CallOptions {
	return &tchannel.CallOptions{
		ShardKey:        f.ShardKey(),
		RoutingKey:      f.RoutingKey(),
		RoutingDelegate: f.RoutingDelegate(),
	}
}

// NewIncomingCall creates an incoming call for tests.
// Only the caller name is set; all other fields of the fake are zero values.
func NewIncomingCall(callerName string) tchannel.IncomingCall {
	return &FakeIncomingCall{CallerNameF: callerName}
}

// FakeCallFrame is a stub implementation of the CallFrame interface.
// Each accessor method returns the value of the matching field.
type FakeCallFrame struct {
	// NOTE(review): tb is not read by any method visible in this file —
	// confirm whether it is still needed.
	tb testing.TB

	// TTLF is returned by TTL.
	TTLF time.Duration

	// These strings are returned, converted to []byte, by the accessor
	// methods of the same name without the F suffix.
	ServiceF, MethodF, CallerF, RoutingKeyF, RoutingDelegateF string

	// Arg2StartOffsetVal is returned by Arg2StartOffset; Arg2EndOffsetVal and
	// IsArg2Fragmented are the two results of Arg2EndOffset.
	Arg2StartOffsetVal, Arg2EndOffsetVal int
	IsArg2Fragmented                     bool

	// arg2KVIterator and hasArg2KVIterator are the iterator and error results
	// of Arg2Iterator. Despite its name, hasArg2KVIterator is an error value.
	arg2KVIterator    arg2.KeyValIterator
	hasArg2KVIterator error

	// Arg2Appends records every key/value pair passed to Arg2Append, in call
	// order.
	Arg2Appends []relay.KeyVal
}

// Compile-time check that *FakeCallFrame satisfies relay.CallFrame.
var _ relay.CallFrame = &FakeCallFrame{}

// TTL returns the TTL field.
func (f *FakeCallFrame) TTL() time.Duration {
	return f.TTLF
}

// Service returns the service name field.
func (f *FakeCallFrame) Service() []byte {
	return []byte(f.ServiceF)
}

// Method returns the method field.
func (f *FakeCallFrame) Method() []byte {
	return []byte(f.MethodF)
}

// Caller returns the caller field.
func (f *FakeCallFrame) Caller() []byte {
	return []byte(f.CallerF)
}

// RoutingKey returns the routing key field.
func (f *FakeCallFrame) RoutingKey() []byte {
	return []byte(f.RoutingKeyF)
}

// RoutingDelegate returns the routing delegate field.
func (f *FakeCallFrame) RoutingDelegate() []byte {
	return []byte(f.RoutingDelegateF)
}

// Arg2StartOffset returns the offset from start of payload to
// the beginning of Arg2.
func (f *FakeCallFrame) Arg2StartOffset() int {
	return f.Arg2StartOffsetVal
}

// Arg2EndOffset returns the offset from start of payload to the end
// of Arg2 and whether Arg2 is fragmented.
func (f *FakeCallFrame) Arg2EndOffset() (int, bool) {
	return f.Arg2EndOffsetVal, f.IsArg2Fragmented
}

// Arg2Iterator returns the iterator for reading Arg2 key value pair
// of TChannel-Thrift Arg Scheme.
// The fake simply returns the iterator and error stored on the frame.
func (f *FakeCallFrame) Arg2Iterator() (arg2.KeyValIterator, error) {
	return f.arg2KVIterator, f.hasArg2KVIterator
}

// Arg2Append appends a key value pair to Arg2.
// The fake only records the pair in Arg2Appends; the key and value slices are
// stored as passed, without copying.
func (f *FakeCallFrame) Arg2Append(key, val []byte) {
	f.Arg2Appends = append(f.Arg2Appends, relay.KeyVal{Key: key, Val: val})
}

// CopyCallFrame copies the relay.CallFrame and returns a FakeCallFrame with
// corresponding values.
// The Arg2 iterator (and any error from building it) comes from
// copyThriftArg2KVIterator. The tb and Arg2Appends fields of the result are
// left at their zero values.
func CopyCallFrame(f relay.CallFrame) *FakeCallFrame {
	endOffset, hasMore := f.Arg2EndOffset()
	copyIterator, err := copyThriftArg2KVIterator(f)

	return &FakeCallFrame{
		TTLF:               f.TTL(),
		ServiceF:           string(f.Service()),
		MethodF:            string(f.Method()),
		CallerF:            string(f.Caller()),
		RoutingKeyF:        string(f.RoutingKey()),
		RoutingDelegateF:   string(f.RoutingDelegate()),
		Arg2StartOffsetVal: f.Arg2StartOffset(),
		Arg2EndOffsetVal:   endOffset,
		IsArg2Fragmented:   hasMore,
		arg2KVIterator:     copyIterator,
		hasArg2KVIterator:  err,
	}
}

// copyThriftArg2KVIterator uses the CallFrame Arg2Iterator to make a
// deep-copy KeyValIterator.
func copyThriftArg2KVIterator(f relay.CallFrame) (arg2.KeyValIterator, error) { kv := make(map[string]string) for iter, err := f.Arg2Iterator(); err == nil; iter, err = iter.Next() { kv[string(iter.Key())] = string(iter.Value()) } return arg2.NewKeyValIterator(thriftarg2test.BuildKVBuffer(kv)) } ================================================ FILE: testutils/channel.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "crypto/tls" "encoding/json" "fmt" "net" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/internal/testcert" "github.com/uber/tchannel-go/raw" "go.uber.org/atomic" "golang.org/x/net/context" ) // NewServerChannel creates a TChannel that is listening and returns the channel. // Passed in options may be mutated (for post-verification of state). 
func NewServerChannel(opts *ChannelOpts) (*tchannel.Channel, error) { opts = opts.Copy() l, err := getListener(opts.ServeTLS) if err != nil { return nil, fmt.Errorf("failed to listen: %v", err) } _, port, err := net.SplitHostPort(l.Addr().String()) if err != nil { return nil, fmt.Errorf("could not get listening port from %v: %v", l.Addr().String(), err) } serviceName := defaultString(opts.ServiceName, DefaultServerName) opts.ProcessName = defaultString(opts.ProcessName, serviceName+"-"+port) updateOptsLogger(opts) ch, err := tchannel.NewChannel(serviceName, &opts.ChannelOptions) if err != nil { return nil, fmt.Errorf("NewChannel failed: %v", err) } if err := ch.Serve(l); err != nil { return nil, fmt.Errorf("Serve failed: %v", err) } return ch, nil } var totalClients atomic.Uint32 // NewClientChannel creates a TChannel that is not listening. // Passed in options may be mutated (for post-verification of state). func NewClientChannel(opts *ChannelOpts) (*tchannel.Channel, error) { opts = opts.Copy() clientNum := totalClients.Inc() serviceName := defaultString(opts.ServiceName, DefaultClientName) opts.ProcessName = defaultString(opts.ProcessName, serviceName+"-"+fmt.Sprint(clientNum)) updateOptsLogger(opts) return tchannel.NewChannel(serviceName, &opts.ChannelOptions) } type rawFuncHandler struct { ch tchannel.Registrar f func(context.Context, *raw.Args) (*raw.Res, error) } func (h rawFuncHandler) OnError(ctx context.Context, err error) { h.ch.Logger().WithFields( tchannel.LogField{Key: "context", Value: ctx}, tchannel.ErrField(err), ).Error("simpleHandler OnError.") } func (h rawFuncHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { return h.f(ctx, args) } // RegisterFunc registers a function as a handler for the given method name. 
func RegisterFunc(ch tchannel.Registrar, name string, f func(ctx context.Context, args *raw.Args) (*raw.Res, error)) { ch.Register(raw.Wrap(rawFuncHandler{ch, f}), name) } // IntrospectJSON returns the introspected state of the channel as a JSON string. func IntrospectJSON(ch *tchannel.Channel, opts *tchannel.IntrospectionOptions) string { state := ch.IntrospectState(opts) marshalled, err := json.MarshalIndent(state, "", " ") if err != nil { return fmt.Sprintf("failed to marshal introspected state: %v", err) } return string(marshalled) } func getListener(serveTLS bool) (net.Listener, error) { if serveTLS { return getTLSListener() } return net.Listen("tcp", "127.0.0.1:0") } func getTLSListener() (net.Listener, error) { cert, err := tls.X509KeyPair(testcert.TestCert, testcert.TestKey) if err != nil { panic(fmt.Sprintf("testutils: getTLSListener: %v", err)) } return tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ Certificates: []tls.Certificate{cert}, }) } ================================================ FILE: testutils/channel_opts.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package testutils

import (
	"flag"
	"net"
	"testing"
	"time"

	"github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/tos"

	"go.uber.org/atomic"
	"golang.org/x/net/context"
)

// connectionLog is the -connectionLog command-line flag (default false).
var connectionLog = flag.Bool("connectionLog", false, "Enables connection logging in tests")

// Default service names for the test channels.
const (
	DefaultServerName = "testService"
	DefaultClientName = "testService-client"
)

// ChannelOpts contains options to create a test channel using WithServer.
// It embeds tchannel.ChannelOptions, so those fields can be set directly, and
// adds test-only settings. The Set* methods return the receiver so that calls
// can be chained.
type ChannelOpts struct {
	tchannel.ChannelOptions

	// ServiceName defaults to DefaultServerName or DefaultClientName.
	ServiceName string

	// LogVerification contains options for controlling the log verification.
	LogVerification LogVerification

	// DisableRelay disables the relay interposed between clients/servers.
	// By default, all tests are run with a relay interposed.
	DisableRelay bool

	// DisableServer disables creation of the TChannel server.
	// This is typically only used in relay tests when a custom server is required.
	DisableServer bool

	// OnlyRelay instructs TestServer the test must only be run with a relay.
	OnlyRelay bool

	// RunCount is the number of times the test should be run. Zero or
	// negative values are treated as a single run.
	RunCount int

	// CheckFramePooling indicates whether we should check for frame leaks or not.
	// This causes the same tests to be run twice, first with the default frame pool,
	// then with the recording frame pool, which will double the amount of time it takes
	// for the test.
	CheckFramePooling bool

	// postFns is a list of functions that are run after the test.
	// They are run even if the test fails.
	postFns []func()

	// ServeTLS enables TLS support on server channel with test certs.
	ServeTLS bool
}

// LogVerification contains options to control the log verification.
type LogVerification struct {
	// Disabled turns log verification off.
	Disabled bool

	// Filters lists the log messages that are allowed to occur.
	Filters []LogFilter
}

// LogFilter is a single substring match that can be ignored.
type LogFilter struct {
	// Filter specifies the substring match to search
	// for in the log message to skip raising an error.
	Filter string

	// Count is the maximum number of allowed warn+ logs matching
	// Filter before errors are raised.
	Count uint

	// FieldFilters specifies expected substring matches for fields.
	FieldFilters map[string]string
}

// Copy copies the channel options (so that they can be safely modified).
// A nil receiver yields a fresh NewOpts(). The copy is shallow: the struct is
// copied by value, so slice and map fields (for example
// LogVerification.Filters and postFns) still share their backing storage with
// the original.
func (o *ChannelOpts) Copy() *ChannelOpts {
	if o == nil {
		return NewOpts()
	}
	copiedOpts := *o
	return &copiedOpts
}

// SetServiceName sets ServiceName.
func (o *ChannelOpts) SetServiceName(svcName string) *ChannelOpts {
	o.ServiceName = svcName
	return o
}

// SetProcessName sets the ProcessName in ChannelOptions.
func (o *ChannelOpts) SetProcessName(processName string) *ChannelOpts {
	o.ProcessName = processName
	return o
}

// SetStatsReporter sets StatsReporter in ChannelOptions.
func (o *ChannelOpts) SetStatsReporter(statsReporter tchannel.StatsReporter) *ChannelOpts {
	o.StatsReporter = statsReporter
	return o
}

// SetFramePool sets FramePool in DefaultConnectionOptions.
func (o *ChannelOpts) SetFramePool(framePool tchannel.FramePool) *ChannelOpts {
	o.DefaultConnectionOptions.FramePool = framePool
	return o
}

// SetHealthChecks sets HealthChecks in DefaultConnectionOptions.
func (o *ChannelOpts) SetHealthChecks(healthChecks tchannel.HealthCheckOptions) *ChannelOpts {
	o.DefaultConnectionOptions.HealthChecks = healthChecks
	return o
}

// SetSendBufferSize sets the SendBufferSize in DefaultConnectionOptions.
func (o *ChannelOpts) SetSendBufferSize(bufSize int) *ChannelOpts {
	o.DefaultConnectionOptions.SendBufferSize = bufSize
	return o
}

// SetSendBufferSizeOverrides sets the SendBufferSizeOverrides in DefaultConnectionOptions.
func (o *ChannelOpts) SetSendBufferSizeOverrides(overrides []tchannel.SendBufferSizeOverride) *ChannelOpts {
	o.DefaultConnectionOptions.SendBufferSizeOverrides = overrides
	return o
}

// SetTosPriority sets TosPriority in DefaultConnectionOptions.
func (o *ChannelOpts) SetTosPriority(tosPriority tos.ToS) *ChannelOpts {
	o.DefaultConnectionOptions.TosPriority = tosPriority
	return o
}

// SetChecksumType sets the ChecksumType in DefaultConnectionOptions.
func (o *ChannelOpts) SetChecksumType(checksumType tchannel.ChecksumType) *ChannelOpts {
	o.DefaultConnectionOptions.ChecksumType = checksumType
	return o
}

// SetTimeNow sets TimeNow in ChannelOptions.
func (o *ChannelOpts) SetTimeNow(timeNow func() time.Time) *ChannelOpts {
	o.TimeNow = timeNow
	return o
}

// SetTimeTicker sets TimeTicker in ChannelOptions.
func (o *ChannelOpts) SetTimeTicker(timeTicker func(d time.Duration) *time.Ticker) *ChannelOpts {
	o.TimeTicker = timeTicker
	return o
}

// DisableLogVerification disables log verification for this channel.
func (o *ChannelOpts) DisableLogVerification() *ChannelOpts {
	o.LogVerification.Disabled = true
	return o
}

// NoRelay disables running the test with a relay interposed.
func (o *ChannelOpts) NoRelay() *ChannelOpts {
	o.DisableRelay = true
	return o
}

// SetRelayOnly instructs TestServer to only run with a relay in front of this channel.
func (o *ChannelOpts) SetRelayOnly() *ChannelOpts {
	o.OnlyRelay = true
	return o
}

// SetDisableServer disables creation of the TChannel server.
// This is typically only used in relay tests when a custom server is required.
func (o *ChannelOpts) SetDisableServer() *ChannelOpts {
	o.DisableServer = true
	return o
}

// SetRunCount sets the number of times to run the test.
func (o *ChannelOpts) SetRunCount(n int) *ChannelOpts { o.RunCount = n return o } // AddLogFilter sets an allowed filter for warning/error logs and sets // the maximum number of times that log can occur. func (o *ChannelOpts) AddLogFilter(filter string, maxCount uint, fields ...string) *ChannelOpts { fieldFilters := make(map[string]string) for i := 0; i < len(fields); i += 2 { fieldFilters[fields[i]] = fields[i+1] } o.LogVerification.Filters = append(o.LogVerification.Filters, LogFilter{ Filter: filter, Count: maxCount, FieldFilters: fieldFilters, }) return o } func (o *ChannelOpts) addPostFn(f func()) { o.postFns = append(o.postFns, f) } // SetRelayHost sets the channel's RelayHost, which enables relaying. func (o *ChannelOpts) SetRelayHost(rh tchannel.RelayHost) *ChannelOpts { o.ChannelOptions.RelayHost = rh return o } // SetRelayLocal sets the channel's relay local handlers for service names // that should be handled by the relay channel itself. func (o *ChannelOpts) SetRelayLocal(relayLocal ...string) *ChannelOpts { o.ChannelOptions.RelayLocalHandlers = relayLocal return o } // SetRelayMaxTimeout sets the maximum allowable timeout for relayed calls. func (o *ChannelOpts) SetRelayMaxTimeout(d time.Duration) *ChannelOpts { o.ChannelOptions.RelayMaxTimeout = d return o } // SetRelayMaxConnectionTimeout sets the maximum timeout for connection attempts. func (o *ChannelOpts) SetRelayMaxConnectionTimeout(d time.Duration) *ChannelOpts { o.ChannelOptions.RelayMaxConnectionTimeout = d return o } // SetRelayMaxTombs sets the maximum number of tombs tracked in the relayer. func (o *ChannelOpts) SetRelayMaxTombs(maxTombs uint64) *ChannelOpts { o.ChannelOptions.RelayMaxTombs = maxTombs return o } // SetOnPeerStatusChanged sets the callback for channel status change // noficiations. 
func (o *ChannelOpts) SetOnPeerStatusChanged(f func(*tchannel.Peer)) *ChannelOpts {
	o.ChannelOptions.OnPeerStatusChanged = f
	return o
}

// SetMaxIdleTime sets a threshold after which idle connections will
// automatically get dropped. See idle_sweep.go for more details.
func (o *ChannelOpts) SetMaxIdleTime(d time.Duration) *ChannelOpts {
	o.ChannelOptions.MaxIdleTime = d
	return o
}

// SetIdleCheckInterval sets the frequency of the periodic poller that removes
// stale connections from the channel.
func (o *ChannelOpts) SetIdleCheckInterval(d time.Duration) *ChannelOpts {
	o.ChannelOptions.IdleCheckInterval = d
	return o
}

// SetDialer sets the dialer used for outbound connections.
func (o *ChannelOpts) SetDialer(f func(context.Context, string, string) (net.Conn, error)) *ChannelOpts {
	o.ChannelOptions.Dialer = f
	return o
}

// SetConnContext sets the connection's ConnContext function.
func (o *ChannelOpts) SetConnContext(f func(context.Context, net.Conn) context.Context) *ChannelOpts {
	o.ConnContext = f
	return o
}

// SetCheckFramePooling sets a flag to enable frame pooling checks such as leaks or bad releases.
func (o *ChannelOpts) SetCheckFramePooling() *ChannelOpts {
	o.CheckFramePooling = true
	return o
}

// SetServeTLS sets the ServeTLS flag to enable/disable TLS for test server.
func (o *ChannelOpts) SetServeTLS(serveTLS bool) *ChannelOpts {
	o.ServeTLS = serveTLS
	return o
}

// defaultString returns v, or defaultValue when v is the empty string.
func defaultString(v string, defaultValue string) string {
	if v == "" {
		return defaultValue
	}
	return v
}

// NewOpts returns a new ChannelOpts that can be used in a chained fashion.
func NewOpts() *ChannelOpts {
	return &ChannelOpts{}
}

// DefaultOpts will return opts if opts is non-nil, NewOpts otherwise.
// Unlike Copy, a non-nil opts is returned as-is rather than copied.
func DefaultOpts(opts *ChannelOpts) *ChannelOpts {
	if opts == nil {
		return NewOpts()
	}
	return opts
}

// WrapLogger wraps the given logger with extra verification.
func (v *LogVerification) WrapLogger(t testing.TB, l tchannel.Logger) tchannel.Logger { return errorLogger{l, t, v, &errorLoggerState{ matchCount: make([]atomic.Uint32, len(v.Filters)), }} } ================================================ FILE: testutils/channel_t.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "testing" "github.com/uber/tchannel-go" "github.com/stretchr/testify/require" ) func updateOptsLogger(opts *ChannelOpts) { if opts.Logger == nil && *connectionLog { opts.Logger = tchannel.SimpleLogger } } func updateOptsForTest(t testing.TB, opts *ChannelOpts) { updateOptsLogger(opts) // If there's no logger, then register the test logger which will record // everything to a buffer, and print out the buffer if the test fails. 
if opts.Logger == nil { tl := newTestLogger(t) opts.Logger = tl opts.addPostFn(tl.report) } if !opts.LogVerification.Disabled { opts.Logger = opts.LogVerification.WrapLogger(t, opts.Logger) } } // WithServer sets up a TChannel that is listening and runs the given function with the channel. func WithServer(t testing.TB, opts *ChannelOpts, f func(ch *tchannel.Channel, hostPort string)) { opts = opts.Copy() updateOptsForTest(t, opts) ch := NewServer(t, opts) f(ch, ch.PeerInfo().HostPort) ch.Close() } // NewServer returns a new TChannel server that listens on :0. func NewServer(t testing.TB, opts *ChannelOpts) *tchannel.Channel { return newServer(t, opts.Copy()) } // newServer must be passed non-nil opts that may be mutated to include // post-verification steps. func newServer(t testing.TB, opts *ChannelOpts) *tchannel.Channel { updateOptsForTest(t, opts) ch, err := NewServerChannel(opts) require.NoError(t, err, "NewServerChannel failed") return ch } // NewClient returns a new TChannel that is not listening. func NewClient(t testing.TB, opts *ChannelOpts) *tchannel.Channel { return newClient(t, opts.Copy()) } // newClient must be passed non-nil opts that may be mutated to include // post-verification steps. func newClient(t testing.TB, opts *ChannelOpts) *tchannel.Channel { updateOptsForTest(t, opts) ch, err := NewClientChannel(opts) require.NoError(t, err, "NewServerChannel failed") return ch } ================================================ FILE: testutils/conn.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "net" "testing" ) // GetClosedHostPort will return a host:port that will refuse connections. func GetClosedHostPort(t testing.TB) string { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("net.Listen failed: %v", err) return "" } if err := listener.Close(); err != nil { t.Fatalf("listener.Close failed") return "" } return listener.Addr().String() } // GetAcceptCloseHostPort returns a host:port that will accept a connection then // immediately close it. The returned function can be used to stop the listener. 
func GetAcceptCloseHostPort(t testing.TB) (string, func()) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("net.Listen failed: %v", err) return "", nil } go func() { for { conn, err := listener.Accept() if err != nil { return } conn.Close() } }() return listener.Addr().String(), func() { if err := listener.Close(); err != nil { t.Fatalf("listener.Close failed") } } } ================================================ FILE: testutils/counter.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "sync" "go.uber.org/atomic" ) // Decrement is the interface returned by Decrementor. type Decrement interface { // Single returns whether any more tokens are remaining. Single() bool // Multiple tries to get n tokens. It returns the actual amount of tokens // available to use. If this is 0, it means there are no tokens left. 
Multiple(n int) int } type decrementor struct { n atomic.Int64 } func (d *decrementor) Single() bool { return d.n.Dec() >= 0 } func (d *decrementor) Multiple(n int) int { decBy := -1 * int64(n) decremented := d.n.Add(decBy) if decremented <= decBy { // Already out of tokens before this decrement. return 0 } else if decremented < 0 { // Not enough tokens, return how many tokens we actually could decrement. return n + int(decremented) } return n } // Decrementor returns a function that can be called from multiple goroutines and ensures // it will only return true n times. func Decrementor(n int) Decrement { return &decrementor{ n: *atomic.NewInt64(int64(n)), } } // Batch returns a slice with n broken into batches of size batchSize. func Batch(n, batchSize int) []int { fullBatches := n / batchSize batches := make([]int, 0, fullBatches+1) for i := 0; i < fullBatches; i++ { batches = append(batches, batchSize) } if remaining := n % batchSize; remaining > 0 { batches = append(batches, remaining) } return batches } // Buckets splits n over the specified number of buckets. func Buckets(n int, numBuckets int) []int { perBucket := n / numBuckets buckets := make([]int, numBuckets) for i := range buckets { buckets[i] = perBucket if i == 0 { buckets[i] += n % numBuckets } } return buckets } // RunN runs the given f n times (and passes the run's index) and waits till they complete. // It starts n-1 goroutines, and runs one instance in the current goroutine. func RunN(n int, f func(i int)) { var wg sync.WaitGroup for i := 0; i < n-1; i++ { wg.Add(1) go func(i int) { defer wg.Done() f(i) }(i) } f(n - 1) wg.Wait() } ================================================ FILE: testutils/counter_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package testutils

import (
	"math/rand"
	"testing"

	"github.com/stretchr/testify/assert"
)

// testDecrementor starts numGoroutines goroutines that each drain tokens from
// a shared Decrementor via f, and asserts that the tokens they report add up
// to exactly the number the Decrementor was created with.
func testDecrementor(t *testing.T, f func(dec Decrement) int) {
	const count = 10000
	const numGoroutines = 100

	dec := Decrementor(count)
	// Buffered so that no goroutine blocks when reporting its result.
	results := make(chan int, numGoroutines)
	for i := 0; i < numGoroutines; i++ {
		go func() {
			results <- f(dec)
		}()
	}

	var total int
	for i := 0; i < numGoroutines; i++ {
		total += <-results
	}

	assert.Equal(t, count, total, "Count mismatch")
}

// TestDecrementSingle drains the Decrementor one token at a time.
func TestDecrementSingle(t *testing.T) {
	testDecrementor(t, func(dec Decrement) int {
		count := 0
		for dec.Single() {
			count++
		}
		return count
	})
}

// TestDecrementMultiple drains the Decrementor in random batches of 1-100
// tokens until Multiple reports that nothing is left.
func TestDecrementMultiple(t *testing.T) {
	testDecrementor(t, func(dec Decrement) int {
		count := 0
		for {
			tokens := dec.Multiple(rand.Intn(100) + 1)
			if tokens == 0 {
				break
			}
			count += tokens
		}
		return count
	})
}

// TestBatch checks Batch for exact multiples, a single partial batch, and
// full batches followed by a remainder.
func TestBatch(t *testing.T) {
	tests := []struct {
		n     int
		batch int
		want  []int
	}{
		{40, 10, []int{10, 10, 10, 10}},
		{5, 10, []int{5}},
		{45, 10, []int{10, 10, 10, 10, 5}},
	}

	for _, tt := range tests {
		got := Batch(tt.n, tt.batch)
		assert.Equal(t, tt.want, got, "Batch(%v, %v) unexpected result", tt.n, tt.batch)
	}
}

// TestBuckets checks that Buckets puts the remainder in the first bucket.
func TestBuckets(t *testing.T) {
	tests := []struct {
		n       int
		buckets int
		want    []int
	}{
		{2, 3, []int{2, 0, 0}},
		{3, 3, []int{1, 1, 1}},
		{4, 3, []int{2, 1, 1}},
	}

	for _, tt := range tests {
		got := Buckets(tt.n, tt.buckets)
		assert.Equal(t, tt.want, got, "Buckets(%v, %v) unexpected result", tt.n, tt.buckets)
	}
}

================================================
FILE: testutils/data.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "encoding/base32" "encoding/binary" "math/rand" "sync" ) // This file contains functions for tests to access internal tchannel state. // Since it has a _test.go suffix, it is only compiled with tests in this package. var ( randCache []byte randMut sync.RWMutex ) func checkCacheSize(n int) { // Start with a reasonably large cache. if n < 1024 { n = 1024 } randMut.RLock() curSize := len(randCache) randMut.RUnlock() // The cache needs to be at least twice as large as the requested size. 
if curSize >= n*2 { return } resizeCache(n) } func resizeCache(n int) { randMut.Lock() defer randMut.Unlock() // Double check under the write lock if len(randCache) >= n*2 { return } newSize := (n * 2 / 8) * 8 newCache := make([]byte, newSize) copied := copy(newCache, randCache) for i := copied; i < newSize; i += 8 { n := rand.Int63() binary.BigEndian.PutUint64(newCache[i:], uint64(n)) } randCache = newCache } // RandBytes returns n random byte slice that points to a shared random byte array. // Since the underlying random array is shared, the returned byte slice must NOT be modified. func RandBytes(n int) []byte { const maxSize = 2 * 1024 * 1024 data := make([]byte, 0, n) for i := 0; i < n; i += maxSize { s := n - i if s > maxSize { s = maxSize } data = append(data, randBytes(s)...) } return data } // RandString returns a random alphanumeric string for testing. func RandString(n int) string { encoding := base32.StdEncoding numBytes := encoding.DecodedLen(n) + 5 return base32.StdEncoding.EncodeToString(RandBytes(numBytes))[:n] } func randBytes(n int) []byte { checkCacheSize(n) randMut.RLock() startAt := rand.Intn(len(randCache) - n) bs := randCache[startAt : startAt+n] randMut.RUnlock() return bs } ================================================ FILE: testutils/echo.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "testing" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/stretchr/testify/assert" "golang.org/x/net/context" ) const ( _defaultTimeout = 300 * time.Millisecond ) // CallEcho calls the "echo" endpoint from the given src to target. func CallEcho( src *tchannel.Channel, targetHostPort string, targetService string, args *raw.Args, ) error { return CallEchoWithContext( context.Background(), src, targetHostPort, targetService, args, ) } // CallEchoWithContext calls the "echo" endpoint from the given src to target, // using any deadline within the given context.Context. func CallEchoWithContext( ctx context.Context, src *tchannel.Channel, targetHostPort string, targetService string, args *raw.Args, ) error { if args == nil { args = &raw.Args{} } timeout := _defaultTimeout dl, ok := ctx.Deadline() if ok { timeout = time.Until(dl) } ctx, cancel := tchannel.NewContextBuilder(Timeout(timeout)). SetConnectBaseContext(ctx). SetFormat(args.Format). Build() defer cancel() _, _, _, err := raw.Call( ctx, src, targetHostPort, targetService, "echo", args.Arg2, args.Arg3, ) return err } // AssertEcho calls the "echo" endpoint with random data, and asserts // that the returned data matches the arguments "echo" was called with. 
func AssertEcho(tb testing.TB, src *tchannel.Channel, targetHostPort, targetService string) { ctx, cancel := tchannel.NewContext(Timeout(_defaultTimeout)) defer cancel() args := &raw.Args{ Arg2: RandBytes(1000), Arg3: RandBytes(1000), } arg2, arg3, _, err := raw.Call(ctx, src, targetHostPort, targetService, "echo", args.Arg2, args.Arg3) if !assert.NoError(tb, err, "Call from %v (%v) to %v (%v) failed", src.ServiceName(), src.PeerInfo().HostPort, targetService, targetHostPort) { return } assert.Equal(tb, args.Arg2, arg2, "Arg2 mismatch") assert.Equal(tb, args.Arg3, arg3, "Arg3 mismatch") } // RegisterEcho registers an echo endpoint on the given channel. The optional provided // function is run before the handler returns. func RegisterEcho(src tchannel.Registrar, f func()) { RegisterFunc(src, "echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { if f != nil { f() } return &raw.Res{Arg2: args.Arg2, Arg3: args.Arg3}, nil }) } // Ping sends a ping from src to target. func Ping(src, target *tchannel.Channel) error { ctx, cancel := tchannel.NewContext(Timeout(_defaultTimeout)) defer cancel() return src.Ping(ctx, target.PeerInfo().HostPort) } ================================================ FILE: testutils/goroutines/stacks.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package goroutines import ( "bufio" "bytes" "fmt" "io" "runtime" "strconv" "strings" ) // Stack represents a single Goroutine's stack. type Stack struct { id int state string firstFunction string fullStack *bytes.Buffer } // ID returns the goroutine ID. func (s Stack) ID() int { return s.id } // State returns the Goroutine's state. func (s Stack) State() string { return s.state } // Full returns the full stack trace for this goroutine. func (s Stack) Full() []byte { return s.fullStack.Bytes() } func (s Stack) String() string { return fmt.Sprintf( "Goroutine %v in state %v, with %v on top of the stack:\n%s", s.id, s.state, s.firstFunction, s.Full()) } func getStacks(all bool) []Stack { var stacks []Stack var curStack *Stack stackReader := bufio.NewReader(bytes.NewReader(getStackBuffer(all))) for { line, err := stackReader.ReadString('\n') if err == io.EOF { break } if err != nil { panic("stack reader failed") } // If we see the goroutine header, start a new stack. 
isFirstLine := false if strings.HasPrefix(line, "goroutine ") { // flush any previous stack if curStack != nil { stacks = append(stacks, *curStack) } id, goState := parseGoStackHeader(line) curStack = &Stack{ id: id, state: goState, fullStack: &bytes.Buffer{}, } isFirstLine = true } curStack.fullStack.WriteString(line) if !isFirstLine && curStack.firstFunction == "" { curStack.firstFunction = parseFirstFunc(line) } } if curStack != nil { stacks = append(stacks, *curStack) } return stacks } // GetAll returns the stacks for all running goroutines. func GetAll() []Stack { return getStacks(true) } // GetCurrentStack returns the stack for the current goroutine. func GetCurrentStack() Stack { return getStacks(false)[0] } func getStackBuffer(all bool) []byte { for i := 4096; ; i *= 2 { buf := make([]byte, i) if n := runtime.Stack(buf, all); n < i { return buf } } } func parseFirstFunc(line string) string { line = strings.TrimSpace(line) if idx := strings.LastIndex(line, "("); idx > 0 { return line[:idx] } return line } // parseGoStackHeader parses a stack header that looks like: // goroutine 643 [runnable]:\n // And returns the goroutine ID, and the state. func parseGoStackHeader(line string) (goroutineID int, state string) { line = strings.TrimSuffix(line, ":\n") parts := strings.SplitN(line, " ", 3) if len(parts) != 3 { panic(fmt.Sprintf("unexpected stack header format: %v", line)) } id, err := strconv.Atoi(parts[1]) if err != nil { panic(fmt.Sprintf("failed to parse goroutine ID: %v", parts[1])) } state = strings.TrimSuffix(strings.TrimPrefix(parts[2], "["), "]") return id, state } ================================================ FILE: testutils/goroutines/verify.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package goroutines import ( "fmt" "runtime" "strings" "testing" "time" ) // filterStacks will filter any stacks excluded by the given VerifyOpts. func filterStacks(stacks []Stack, skipID int, opts *VerifyOpts) []Stack { filtered := stacks[:0] for _, stack := range stacks { if stack.ID() == skipID || shouldIgnore(stack) { continue } if opts.ShouldSkip(stack) { continue } filtered = append(filtered, stack) } return filtered } func shouldIgnore(s Stack) bool { switch funcName := s.firstFunction; funcName { case "testing.RunTests", "testing.(*T).Run": return strings.HasPrefix(s.State(), "chan receive") case "runtime.goexit": return strings.HasPrefix(s.State(), "syscall") case "os/signal.signal_recv": // The signal package automatically starts a goroutine when it's imported. return true default: return false } } // IdentifyLeaks looks for extra goroutines, and returns a descriptive error if // it finds any. 
func IdentifyLeaks(opts *VerifyOpts) error { cur := GetCurrentStack().id const maxAttempts = 50 var stacks []Stack for i := 0; i < maxAttempts; i++ { stacks = GetAll() stacks = filterStacks(stacks, cur, opts) if len(stacks) == 0 { return nil } if i > maxAttempts/2 { time.Sleep(time.Duration(i) * time.Millisecond) } else { runtime.Gosched() } } return fmt.Errorf("found unexpected goroutines:\n%s", stacks) } // VerifyNoLeaks calls IdentifyLeaks and fails the test if it finds any leaked // goroutines. func VerifyNoLeaks(t testing.TB, opts *VerifyOpts) { if err := IdentifyLeaks(opts); err != nil { t.Error(err.Error()) } } ================================================ FILE: testutils/goroutines/verify_opts.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package goroutines

import "bytes"

// VerifyOpts contains options that control which goroutine stacks are
// considered leaks during verification.
type VerifyOpts struct {
	// Excludes is a list of strings that will exclude a stack from being considered a leak.
	Excludes []string
}

// ShouldSkip reports whether the given stack should be skipped during
// verification, which is the case when its full trace contains any of the
// Excludes strings. A nil opts skips nothing.
func (opts *VerifyOpts) ShouldSkip(s Stack) bool {
	if opts == nil {
		return false
	}
	if len(opts.Excludes) == 0 {
		return false
	}

	trace := s.Full()
	for _, exclude := range opts.Excludes {
		if bytes.Contains(trace, []byte(exclude)) {
			return true
		}
	}
	return false
}

================================================
FILE: testutils/lists.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package testutils

import "time"

// StrArray returns its arguments as a slice of strings.
func StrArray(ss ...string) []string {
	return ss
}

// StrMap returns a map where the keys are the given strings.
func StrMap(ss ...string) map[string]struct{} { m := make(map[string]struct{}, len(ss)) for _, v := range ss { m[v] = struct{}{} } return m } // DurationArray returns an array with the given durations. func DurationArray(dd ...time.Duration) []time.Duration { return dd } ================================================ FILE: testutils/logfilter_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "testing" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go" ) func TestLogFilterMatches(t *testing.T) { msgFilter := LogFilter{ Filter: "msgFilter", } fieldsFilter := LogFilter{ Filter: "msgFilter", FieldFilters: map[string]string{ "f1": "v1", "f2": "v2", }, } // fields takes a varargs list of strings which it reads as: // key, value, key, value... 
fields := func(vals ...string) []tchannel.LogField { fs := make([]tchannel.LogField, len(vals)/2) for i := 0; i < len(vals); i += 2 { fs[i/2] = tchannel.LogField{ Key: vals[i], Value: vals[i+1], } } return fs } tests := []struct { Filter LogFilter Message string Fields []tchannel.LogField Match bool }{ { Filter: msgFilter, Message: "random message", Match: false, }, { Filter: msgFilter, Message: "msgFilter", Match: true, }, { // Case matters. Filter: msgFilter, Message: "msgfilter", Match: false, }, { Filter: msgFilter, Message: "abc msgFilterdef", Match: true, }, { Filter: fieldsFilter, Message: "random message", Fields: fields("f1", "v1", "f2", "v2"), Match: false, }, { Filter: fieldsFilter, Message: "msgFilter", Fields: fields("f1", "v1", "f2", "v2"), Match: true, }, { // Field mismatch should not match. Filter: fieldsFilter, Message: "msgFilter", Fields: fields("f1", "v0", "f2", "v2"), Match: false, }, { // Missing field should not match. Filter: fieldsFilter, Message: "msgFilter", Fields: fields("f2", "v2"), Match: false, }, { // Extra fields are OK. Filter: fieldsFilter, Message: "msgFilter", Fields: fields("f1", "v0", "f2", "v2", "f3", "v3"), Match: false, }, } for _, tt := range tests { got := tt.Filter.Matches(tt.Message, tt.Fields) assert.Equal(t, tt.Match, got, "Filter %+v .Matches(%v, %v) mismatch", tt.Filter, tt.Message, tt.Fields) } } ================================================ FILE: testutils/logger.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "bytes" "fmt" "os" "strings" "sync" "testing" "time" "github.com/uber/tchannel-go" "go.uber.org/atomic" ) // writer is shared between multiple loggers, and serializes acccesses to // the underlying buffer. type writer struct { sync.Mutex buf *bytes.Buffer } // testLogger is a logger that writes all output to a buffer, and can report // the logs if the test has failed. type testLogger struct { t testing.TB fields tchannel.LogFields w *writer } type errorLoggerState struct { matchCount []atomic.Uint32 } type errorLogger struct { tchannel.Logger t testing.TB v *LogVerification s *errorLoggerState } func newWriter() *writer { return &writer{buf: &bytes.Buffer{}} } func (w *writer) withLock(f func(*bytes.Buffer)) { w.Lock() f(w.buf) w.Unlock() } // Matches returns true if the message and fields match the filter. 
func (f LogFilter) Matches(msg string, fields tchannel.LogFields) bool { // First check the message and ensure it contains Filter if !strings.Contains(msg, f.Filter) { return false } // if there are no field filters, then the message match is enough. if len(f.FieldFilters) == 0 { return true } fieldsMap := make(map[string]interface{}) for _, field := range fields { fieldsMap[field.Key] = field.Value } for k, filter := range f.FieldFilters { value, ok := fieldsMap[k] if !ok { return false } if !strings.Contains(fmt.Sprint(value), filter) { return false } } return true } func newTestLogger(t testing.TB) testLogger { return testLogger{t, nil, newWriter()} } func (l testLogger) Enabled(level tchannel.LogLevel) bool { return true } func (l testLogger) log(prefix string, msg string) { logLine := fmt.Sprintf("%s [%v] %v %v\n", time.Now().Format("15:04:05.000000"), prefix, msg, l.Fields()) l.w.withLock(func(w *bytes.Buffer) { w.WriteString(logLine) }) } func (l testLogger) Fatal(msg string) { l.log("F", msg) } func (l testLogger) Error(msg string) { l.log("E", msg) } func (l testLogger) Warn(msg string) { l.log("W", msg) } func (l testLogger) Info(msg string) { l.log("I", msg) } func (l testLogger) Infof(msg string, args ...interface{}) { l.log("I", fmt.Sprintf(msg, args...)) } func (l testLogger) Debug(msg string) { l.log("D", msg) } func (l testLogger) Debugf(msg string, args ...interface{}) { l.log("D", fmt.Sprintf(msg, args...)) } func (l testLogger) Fields() tchannel.LogFields { return l.fields } func (l testLogger) WithFields(fields ...tchannel.LogField) tchannel.Logger { existing := len(l.Fields()) newFields := make(tchannel.LogFields, existing+len(fields)) copy(newFields, l.Fields()) copy(newFields[existing:], fields) return testLogger{l.t, newFields, l.w} } func (l testLogger) report() { if os.Getenv("LOGS_ON_FAILURE") == "" { return } if l.t.Failed() { l.w.withLock(func(w *bytes.Buffer) { l.t.Logf("Debug logs:\n%s", w.String()) }) } } // checkFilters returns 
whether the message can be ignored by the filters. func (l errorLogger) checkFilters(msg string) bool { match := -1 for i, filter := range l.v.Filters { if filter.Matches(msg, l.Fields()) { match = i } } if match == -1 { return false } matchCount := l.s.matchCount[match].Inc() return uint(matchCount) <= l.v.Filters[match].Count } func (l errorLogger) checkErr(prefix, msg string) { if l.checkFilters(msg) { return } l.t.Errorf("Unexpected log: %v: %s %v", prefix, msg, l.Logger.Fields()) } func (l errorLogger) Fatal(msg string) { l.checkErr("[Fatal]", msg) l.Logger.Fatal(msg) } func (l errorLogger) Error(msg string) { l.checkErr("[Error]", msg) l.Logger.Error(msg) } func (l errorLogger) Warn(msg string) { l.checkErr("[Warn]", msg) l.Logger.Warn(msg) } func (l errorLogger) WithFields(fields ...tchannel.LogField) tchannel.Logger { return errorLogger{l.Logger.WithFields(fields...), l.t, l.v, l.s} } ================================================ FILE: testutils/mockhyperbahn/hyperbahn.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package mockhyperbahn import ( "errors" "fmt" "sync" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/hyperbahn" hthrift "github.com/uber/tchannel-go/hyperbahn/gen-go/hyperbahn" "github.com/uber/tchannel-go/json" "github.com/uber/tchannel-go/relay/relaytest" "github.com/uber/tchannel-go/thrift" ) // Mock is up a mock Hyperbahn server for tests. type Mock struct { sync.RWMutex ch *tchannel.Channel respCh chan int advertised []string discoverResults map[string][]string } // New returns a mock Hyperbahn server that can be used for testing. func New() (*Mock, error) { stubHost := relaytest.NewStubRelayHost() ch, err := tchannel.NewChannel("hyperbahn", &tchannel.ChannelOptions{ RelayHost: stubHost, RelayLocalHandlers: []string{"hyperbahn"}, }) if err != nil { return nil, err } mh := &Mock{ ch: ch, respCh: make(chan int), discoverResults: make(map[string][]string), } if err := json.Register(ch, json.Handlers{"ad": mh.adHandler}, nil); err != nil { return nil, err } thriftServer := thrift.NewServer(ch) thriftServer.Register(hthrift.NewTChanHyperbahnServer(mh)) return mh, ch.ListenAndServe("127.0.0.1:0") } // SetDiscoverResult sets the given hostPorts as results for the Discover call. func (h *Mock) SetDiscoverResult(serviceName string, hostPorts []string) { h.Lock() defer h.Unlock() h.discoverResults[serviceName] = hostPorts } // Discover returns the IPs for a discovery query if some were set using SetDiscoverResult. // Otherwise, it returns an error. 
func (h *Mock) Discover(ctx thrift.Context, query *hthrift.DiscoveryQuery) (*hthrift.DiscoveryResult_, error) { h.RLock() defer h.RUnlock() hostPorts, ok := h.discoverResults[query.ServiceName] if !ok { return nil, fmt.Errorf("no discovery results set for %v", query.ServiceName) } peers, err := toServicePeers(hostPorts) if err != nil { return nil, fmt.Errorf("invalid discover result set: %v", err) } return &hthrift.DiscoveryResult_{ Peers: peers, }, nil } // Configuration returns a hyperbahn.Configuration object used to configure a // hyperbahn.Client to talk to this mock server. func (h *Mock) Configuration() hyperbahn.Configuration { return hyperbahn.Configuration{ InitialNodes: []string{h.ch.PeerInfo().HostPort}, } } // Channel returns the underlying tchannel that implements relaying. func (h *Mock) Channel() *tchannel.Channel { return h.ch } func (h *Mock) adHandler(ctx json.Context, req *hyperbahn.AdRequest) (*hyperbahn.AdResponse, error) { callerHostPort := tchannel.CurrentCall(ctx).RemotePeer().HostPort h.Lock() for _, s := range req.Services { h.advertised = append(h.advertised, s.Name) sc := h.ch.GetSubChannel(s.Name, tchannel.Isolated) sc.Peers().Add(callerHostPort) } h.Unlock() select { case n := <-h.respCh: if n == 0 { return nil, errors.New("error") } return &hyperbahn.AdResponse{ConnectionCount: n}, nil default: // Return a default response return &hyperbahn.AdResponse{ConnectionCount: 3}, nil } } // GetAdvertised returns the list of services registered. func (h *Mock) GetAdvertised() []string { h.RLock() defer h.RUnlock() return h.advertised } // Close stops the mock Hyperbahn server. func (h *Mock) Close() { h.ch.Close() } // QueueError queues an error to be returned on the next advertise call. func (h *Mock) QueueError() { h.respCh <- 0 } // QueueResponse queues a response from Hyperbahn. // numConnections must be greater than 0. 
func (h *Mock) QueueResponse(numConnections int) { if numConnections <= 0 { panic("QueueResponse must have numConnections > 0") } h.respCh <- numConnections } ================================================ FILE: testutils/mockhyperbahn/hyperbahn_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package mockhyperbahn_test import ( "testing" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/hyperbahn" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/testutils/mockhyperbahn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" ) var config = struct { hyperbahnConfig hyperbahn.Configuration }{} // setupServer is the application code we are attempting to test. 
func setupServer() (*hyperbahn.Client, error) { ch, err := tchannel.NewChannel("myservice", nil) if err != nil { return nil, err } if err := ch.ListenAndServe("127.0.0.1:0"); err != nil { return nil, err } client, err := hyperbahn.NewClient(ch, config.hyperbahnConfig, nil) if err != nil { return nil, err } return client, client.Advertise() } func newAdvertisedEchoServer(t *testing.T, name string, mockHB *mockhyperbahn.Mock, f func()) *tchannel.Channel { server := testutils.NewServer(t, &testutils.ChannelOpts{ ServiceName: name, }) testutils.RegisterEcho(server, f) hbClient, err := hyperbahn.NewClient(server, mockHB.Configuration(), nil) require.NoError(t, err, "Failed to set up Hyperbahn client") require.NoError(t, hbClient.Advertise(), "Advertise failed") return server } func TestMockHyperbahn(t *testing.T) { mh, err := mockhyperbahn.New() require.NoError(t, err, "mock hyperbahn failed") defer mh.Close() config.hyperbahnConfig = mh.Configuration() _, err = setupServer() require.NoError(t, err, "setupServer failed") assert.Equal(t, []string{"myservice"}, mh.GetAdvertised()) } func TestMockDiscovery(t *testing.T) { mh, err := mockhyperbahn.New() require.NoError(t, err, "mock hyperbahn failed") defer mh.Close() peers := []string{ "1.3.5.7:1456", "255.255.255.255:25", } mh.SetDiscoverResult("discover-svc", peers) config.hyperbahnConfig = mh.Configuration() client, err := setupServer() require.NoError(t, err, "setupServer failed") gotPeers, err := client.Discover("discover-svc") require.NoError(t, err, "Discover failed") assert.Equal(t, peers, gotPeers, "Discover returned invalid peers") } func TestMockForwards(t *testing.T) { mockHB, err := mockhyperbahn.New() require.NoError(t, err, "Failed to set up mock hyperbahm") called := false server := newAdvertisedEchoServer(t, "svr", mockHB, func() { called = true }) defer server.Close() client := newAdvertisedEchoServer(t, "client", mockHB, nil) defer client.Close() ctx, cancel := tchannel.NewContext(time.Second) defer 
cancel() _, _, _, err = raw.CallSC(ctx, client.GetSubChannel("svr"), "echo", nil, nil) require.NoError(t, err, "Call failed") require.True(t, called, "Advertised server was not called") } func TestMockIgnoresDown(t *testing.T) { mockHB, err := mockhyperbahn.New() require.NoError(t, err, "Failed to set up mock hyperbahm") var ( moe1Called atomic.Bool moe2Called atomic.Bool ) moe1 := newAdvertisedEchoServer(t, "moe", mockHB, func() { moe1Called.Store(true) }) defer moe1.Close() moe2 := newAdvertisedEchoServer(t, "moe", mockHB, func() { moe2Called.Store(true) }) defer moe2.Close() client := newAdvertisedEchoServer(t, "client", mockHB, nil) ctx, cancel := tchannel.NewContext(time.Second) defer cancel() for i := 0; i < 20; i++ { _, _, _, err = raw.CallSC(ctx, client.GetSubChannel("moe"), "echo", nil, nil) assert.NoError(t, err, "Call failed") } require.True(t, moe1Called.Load(), "moe1 not called") require.True(t, moe2Called.Load(), "moe2 not called") // If moe2 is brought down, all calls should now be sent to moe1. moe2.Close() // Wait for the mock HB to have 0 connections to moe ok := testutils.WaitFor(time.Second, func() bool { in, out := mockHB.Channel().Peers().GetOrAdd(moe2.PeerInfo().HostPort).NumConnections() return in+out == 0 }) require.True(t, ok, "Failed waiting for mock HB to have 0 connections") // Make sure that all calls succeed (they should all go to moe2) moe1Called.Store(false) moe2Called.Store(false) for i := 0; i < 20; i++ { _, _, _, err = raw.CallSC(ctx, client.GetSubChannel("moe"), "echo", nil, nil) assert.NoError(t, err, "Call failed") } require.True(t, moe1Called.Load(), "moe1 not called") require.False(t, moe2Called.Load(), "moe2 should not be called after Close") } ================================================ FILE: testutils/mockhyperbahn/utils.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package mockhyperbahn import ( "fmt" "net" "strconv" hthrift "github.com/uber/tchannel-go/hyperbahn/gen-go/hyperbahn" ) func toServicePeer(hostPort string) (*hthrift.ServicePeer, error) { host, port, err := net.SplitHostPort(hostPort) if err != nil { return nil, fmt.Errorf("invalid hostPort %v: %v", hostPort, err) } ip := net.ParseIP(host) if ip == nil { return nil, fmt.Errorf("host %v is not an ip", host) } ip = ip.To4() if len(ip) != net.IPv4len { return nil, fmt.Errorf("ip %v is not a v4 ip, expected length to be %v, got %v", host, net.IPv4len, len(ip)) } portInt, err := strconv.Atoi(port) if err != nil { return nil, fmt.Errorf("invalid port %v: %v", port, err) } // We have 4 bytes for the IP, use that as an int. 
ipInt := int32(uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])) return &hthrift.ServicePeer{ IP: &hthrift.IpAddress{Ipv4: &ipInt}, Port: int32(portInt), }, nil } func toServicePeers(hostPorts []string) ([]*hthrift.ServicePeer, error) { var peers []*hthrift.ServicePeer for _, hostPort := range hostPorts { peer, err := toServicePeer(hostPort) if err != nil { return nil, err } peers = append(peers, peer) } return peers, nil } ================================================ FILE: testutils/now.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "sync" "time" ) // StubClock is a fake wall-clock, exposing a Now() method that returns a // test-controlled time. 
type StubClock struct { mu sync.Mutex cur time.Time } // NewStubClock returns a fake wall-clock object func NewStubClock(initial time.Time) *StubClock { return &StubClock{ cur: initial, } } // Now returns the current time stored in StubClock func (c *StubClock) Now() time.Time { c.mu.Lock() defer c.mu.Unlock() return c.cur } // Elapse increments the time returned by Now() func (c *StubClock) Elapse(addAmt time.Duration) { c.mu.Lock() defer c.mu.Unlock() c.cur = c.cur.Add(addAmt) } ================================================ FILE: testutils/random_bench_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package testutils

import (
	"bytes"
	"io"
	"io/ioutil"
	"testing"
)

// benchmarkRandom measures RandBytes for numBytes-sized requests.
func benchmarkRandom(b *testing.B, numBytes int) {
	var last []byte
	for iter := 0; iter < b.N; iter++ {
		// randCache is cleared before every call; presumably this forces
		// RandBytes to produce fresh data each iteration (verify against
		// RandBytes).
		randCache = nil
		last = RandBytes(numBytes)
	}

	// Read the final result so that it is actually used.
	io.Copy(ioutil.Discard, bytes.NewReader(last))
}

// BenchmarkRandom256 benchmarks RandBytes with 256-byte requests.
func BenchmarkRandom256(b *testing.B) {
	benchmarkRandom(b, 256)
}

// BenchmarkRandom1024 benchmarks RandBytes with 1024-byte requests.
func BenchmarkRandom1024(b *testing.B) {
	benchmarkRandom(b, 1024)
}

// BenchmarkRandom4096 benchmarks RandBytes with 4096-byte requests.
func BenchmarkRandom4096(b *testing.B) {
	benchmarkRandom(b, 4096)
}

// BenchmarkRandom16384 benchmarks RandBytes with 16384-byte requests.
func BenchmarkRandom16384(b *testing.B) {
	benchmarkRandom(b, 16384)
}

// BenchmarkRandom32768 benchmarks RandBytes with 32768-byte requests.
func BenchmarkRandom32768(b *testing.B) {
	benchmarkRandom(b, 32768)
}

================================================
FILE: testutils/relay.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package testutils

import (
	"io"
	"net"
	"sync"
	"testing"

	"github.com/uber/tchannel-go"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/atomic"
)

// frameRelay is a TCP relay that forwards TChannel frames between accepted
// connections and a fixed destination, passing every frame through relayFunc.
type frameRelay struct {
	sync.Mutex // protects conns

	t           testing.TB
	destination string
	relayFunc   func(outgoing bool, f *tchannel.Frame) *tchannel.Frame

	closed atomic.Uint32 // non-zero once the relay has been cancelled
	conns  []net.Conn    // every accepted and dialed connection
	wg     sync.WaitGroup
}

// listen starts accepting connections on an ephemeral local port and relays
// each of them to the destination. It returns the address being listened on
// and a cancel function that closes the listener and all connections, then
// waits for the relaying goroutines to finish.
func (r *frameRelay) listen() (listenHostPort string, cancel func()) {
	conn, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(r.t, err, "net.Listen failed")

	go func() {
		for {
			c, err := conn.Accept()
			if err != nil {
				// Accept errors are expected once the relay is closed.
				if r.closed.Load() == 0 {
					r.t.Errorf("Accept failed: %v", err)
				}
				return
			}

			r.Lock()
			r.conns = append(r.conns, c)
			r.Unlock()

			r.relayConn(c)
		}
	}()

	return conn.Addr().String(), func() {
		r.closed.Inc()
		conn.Close()
		r.Lock()
		for _, c := range r.conns {
			c.Close()
		}
		r.Unlock()

		// Wait for all the outbound connections we created to close.
		r.wg.Wait()
	}
}

// relayConn dials the destination and starts relaying frames in both
// directions between c and the new outbound connection.
//
// If the dial fails, or the relay was cancelled in the meantime, c is closed
// here: cancel may already have gone through the tracked connections before c
// was added, in which case nothing else would ever close it.
func (r *frameRelay) relayConn(c net.Conn) {
	outC, err := net.Dial("tcp", r.destination)
	if !assert.NoError(r.t, err, "relay connection failed") {
		c.Close()
		return
	}

	r.Lock()
	defer r.Unlock()

	if r.closed.Load() > 0 {
		outC.Close()
		c.Close()
		return
	}

	r.conns = append(r.conns, outC)

	r.wg.Add(2)
	go r.relayBetween(true /* outgoing */, c, outC)
	go r.relayBetween(false /* outgoing */, outC, c)
}

// relayBetween reads frames from c, passes each one to relayFunc, and writes
// the frame it returns to outC; a nil result drops the frame. It returns on
// EOF, on any error after the relay has been cancelled, or after reporting an
// unexpected read or write error to the test.
func (r *frameRelay) relayBetween(outgoing bool, c net.Conn, outC net.Conn) {
	defer r.wg.Done()

	frame := tchannel.NewFrame(tchannel.MaxFramePayloadSize)
	for {
		err := frame.ReadIn(c)
		if err == io.EOF {
			// Connection gracefully closed.
			return
		}
		if err != nil && r.closed.Load() > 0 {
			// Once the relay is shutdown, we expect connection errors.
			return
		}
		if !assert.NoError(r.t, err, "read frame failed") {
			return
		}

		outFrame := r.relayFunc(outgoing, frame)
		if outFrame == nil {
			continue
		}

		err = outFrame.WriteOut(outC)
		if err != nil && r.closed.Load() > 0 {
			// Once the relay is shutdown, we expect connection errors.
			return
		}
		if !assert.NoError(r.t, err, "write frame failed") {
			return
		}
	}
}

// FrameRelay sets up a relay that can modify frames using relayFunc.
//
// relayFunc is called for every frame, with outgoing set to true for frames
// travelling from an accepted connection towards destination. The frame it
// returns is forwarded; returning nil drops the frame. The returned cancel
// function shuts the relay down.
func FrameRelay(t testing.TB, destination string, relayFunc func(outgoing bool, f *tchannel.Frame) *tchannel.Frame) (listenHostPort string, cancel func()) {
	relay := &frameRelay{
		t:           t,
		destination: destination,
		relayFunc:   relayFunc,
	}
	return relay.listen()
}

================================================
FILE: testutils/sleep.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package testutils

import "time"

// SleepStub stubs a function variable that points to time.Sleep. It returns
// two channels to control the sleep stub, and a function to close the channels.
// Once the stub is closed, any further sleeps will cause panics.
// The two channels returned are: // <-chan time.Duration which will contain arguments that the stub was called with. // chan<- struct{} that should be written to when you want the Sleep to return. func SleepStub(funcVar *func(time.Duration)) ( argCh <-chan time.Duration, unblockCh chan<- struct{}, closeFn func()) { args := make(chan time.Duration) block := make(chan struct{}) *funcVar = func(t time.Duration) { args <- t <-block } closeSleepChans := func() { close(args) close(block) } return args, block, closeSleepChans } // ResetSleepStub resets a Sleep stub. func ResetSleepStub(funcVar *func(time.Duration)) { *funcVar = time.Sleep } ================================================ FILE: testutils/test_server.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package testutils

import (
	"fmt"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/raw"
	"github.com/uber/tchannel-go/relay/relaytest"
	"github.com/uber/tchannel-go/testutils/goroutines"
	"go.uber.org/multierr"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/atomic"
	"golang.org/x/net/context"
)

// Has a previous test already leaked a goroutine?
var _leakedGoroutine atomic.Bool

// A TestServer encapsulates a TChannel server, a client factory, and functions
// to ensure that we're not leaking resources.
type TestServer struct {
	testing.TB

	// References to specific channels (if any, as they can be disabled)
	relayCh  *tchannel.Channel
	serverCh *tchannel.Channel

	// relayHost is the relayer's StubRelayHost (if any).
	relayHost *relaytest.StubRelayHost

	// relayStats is the backing stats for the relay.
	// Note: if a user passes a custom RelayHosts that does not implement
	// relayStatter, then this will be nil, and relay stats cannot be verified.
	relayStats *relaytest.MockStats

	// channels is the list of channels created for this TestServer. The first
	// element is always the initial server.
	channels []*tchannel.Channel

	// channelStates is the initial runtime state for all channels created
	// as part of the TestServer (including the server).
	channelStates map[*tchannel.Channel]*tchannel.RuntimeState

	introspectOpts *tchannel.IntrospectionOptions
	verifyOpts     *goroutines.VerifyOpts
	postFns        []func()
}

// relayStatter is implemented by relay hosts that can expose their stats.
type relayStatter interface {
	Stats() *relaytest.MockStats
}

// NewTestServer constructs a TestServer.
//
// A server channel is created unless opts.DisableServer is set, and a relay is
// put in front of it unless opts.DisableRelay is set. A nil opts is treated as
// having neither option set.
func NewTestServer(t testing.TB, opts *ChannelOpts) *TestServer {
	ts := &TestServer{
		TB:            t,
		channelStates: make(map[*tchannel.Channel]*tchannel.RuntimeState),
		introspectOpts: &tchannel.IntrospectionOptions{
			IncludeExchanges:  true,
			IncludeTombstones: true,
		},
	}

	// opts may be nil (the relay check below and addRelay allow for it), so
	// guard the field access instead of dereferencing a nil pointer.
	if opts == nil || !opts.DisableServer {
		// Remove any relay options, since those should only be applied to addRelay.
		serverOpts := opts.Copy()
		serverOpts.RelayHost = nil
		ts.serverCh = ts.NewServer(serverOpts)
	}

	if opts == nil || !opts.DisableRelay {
		ts.addRelay(opts)
	}

	return ts
}

// runSubTest runs the specified function as a sub-test of a testing.T or
// testing.B if the types match.
// For any other testing.TB implementation, f is called directly with t.
func runSubTest(t testing.TB, name string, f func(testing.TB)) {
	switch t := t.(type) {
	case *testing.T:
		t.Run(name, func(t *testing.T) { f(t) })
	case *testing.B:
		t.Run(name, func(b *testing.B) { f(b) })
	default:
		f(t)
	}
}

// WithTestServer creates a new TestServer, runs the passed function, and then
// verifies that no resources were leaked.
//
// The function is run RunCount times (at least once), each time as a
// "no relay" sub-test (unless OnlyRelay is set) and as a "with relay" sub-test
// (unless DisableRelay is set). When CheckFramePooling is set and the
// DISABLE_FRAME_POOLING_CHECKS environment variable is empty, everything is
// run once more with a checked frame pool to detect frame leaks.
func WithTestServer(t testing.TB, chanOpts *ChannelOpts, f func(testing.TB, *TestServer)) {
	runTest := func(t testing.TB, chanOpts *ChannelOpts) {
		runCount := chanOpts.RunCount
		if runCount < 1 {
			runCount = 1
		}

		for i := 0; i < runCount; i++ {
			// Stop repeating as soon as one run has failed.
			if t.Failed() {
				return
			}

			// Run without the relay, unless OnlyRelay was set.
			if !chanOpts.OnlyRelay {
				runSubTest(t, "no relay", func(t testing.TB) {
					noRelayOpts := chanOpts.Copy()
					noRelayOpts.DisableRelay = true
					withServer(t, noRelayOpts, f)
				})
			}

			// Run with the relay, unless the user has disabled it.
			if !chanOpts.DisableRelay {
				runSubTest(t, "with relay", func(t testing.TB) {
					withServer(t, chanOpts.Copy(), f)
				})

				// Re-run the same test with timer verification if this is a relay-only test.
				if chanOpts.OnlyRelay {
					runSubTest(t, "with relay and timer verification", func(t testing.TB) {
						verifyOpts := chanOpts.Copy()
						verifyOpts.RelayTimerVerification = true
						withServer(t, verifyOpts, f)
					})
				}
			}
		}
	}

	chanOptsCopy := chanOpts.Copy()
	runTest(t, chanOptsCopy)

	if os.Getenv("DISABLE_FRAME_POOLING_CHECKS") == "" && chanOptsCopy.CheckFramePooling {
		runSubTest(t, "check frame leaks", func(t testing.TB) {
			pool := tchannel.NewCheckedFramePoolForTest()
			runTest(t, chanOpts.Copy().SetFramePool(pool))
			result := pool.CheckEmpty()
			if len(result.Unreleased) > 0 {
				t.Errorf("Frame pool has %v unreleased frames, errors:\n%v\n", len(result.Unreleased), strings.Join(result.Unreleased, "\n"))
			}
			if len(result.BadReleases) > 0 {
				t.Errorf("Frame pool has %v bad releases, errors:\n%v\n", len(result.BadReleases), strings.Join(result.BadReleases, "\n"))
			}
		})
	}
}

// SetVerifyOpts specifies the options we'll use during teardown to verify that
// no goroutines were leaked.
func (ts *TestServer) SetVerifyOpts(opts *goroutines.VerifyOpts) {
	ts.verifyOpts = opts
}

// HasServer returns whether this TestServer has a TChannel server, as
// the server may have been disabled with the DisableServer option.
func (ts *TestServer) HasServer() bool {
	return ts.serverCh != nil
}

// Server returns the underlying TChannel for the server (i.e., the channel on
// which we're registering handlers).
//
// To support test cases with relays interposed between clients and servers,
// callers should use the Client(), HostPort(), ServiceName(), and Register()
// methods instead of accessing the server channel explicitly.
func (ts *TestServer) Server() *tchannel.Channel {
	require.True(ts, ts.HasServer(), "Cannot use Server as it was disabled")
	return ts.serverCh
}

// HasRelay indicates whether this TestServer has a relay interposed between the
// server and clients.
func (ts *TestServer) HasRelay() bool {
	return ts.relayCh != nil
}

// Relay returns the relay channel, if one is present.
func (ts *TestServer) Relay() *tchannel.Channel { require.True(ts, ts.HasRelay(), "Cannot use Relay, not present in current test") return ts.relayCh } // RelayHost returns the stub RelayHost for mapping service names to peers. func (ts *TestServer) RelayHost() *relaytest.StubRelayHost { return ts.relayHost } // HostPort returns the host:port for clients to connect to. Note that this may // not be the same as the host:port of the server channel. func (ts *TestServer) HostPort() string { if ts.HasRelay() { return ts.Relay().PeerInfo().HostPort } return ts.Server().PeerInfo().HostPort } // ServiceName returns the service name of the server channel. func (ts *TestServer) ServiceName() string { return ts.Server().PeerInfo().ServiceName } // Register registers a handler on the server channel. func (ts *TestServer) Register(h tchannel.Handler, methodName string) { ts.Server().Register(h, methodName) } // RegisterFunc registers a function as a handler for the given method name. // // TODO: Delete testutils.RegisterFunc in favor of this test server. func (ts *TestServer) RegisterFunc(name string, f func(context.Context, *raw.Args) (*raw.Res, error)) { ts.Register(raw.Wrap(rawFuncHandler{ts.Server(), f}), name) } // CloseAndVerify closes all channels verifying each channel as it is closed. // It then verifies that no goroutines were leaked. func (ts *TestServer) CloseAndVerify() { // Verify channels before they are closed to ensure that we catch any // unexpected pending exchanges. var verify sync.WaitGroup for i := len(ts.channels) - 1; i >= 0; i-- { ch := ts.channels[i] verify.Add(1) go func() { defer verify.Done() ch.Logger().Debugf("TEST: TestServer is verifying channel") ts.verify(ch) }() } verify.Wait() // Close the connection, then verify again to ensure connection close didn't // cause any unexpected issues. 
var closeVerify sync.WaitGroup for i := len(ts.channels) - 1; i >= 0; i-- { ch := ts.channels[i] closeVerify.Add(1) go func() { defer closeVerify.Done() ch.Logger().Debugf("TEST: TestServer is closing and verifying channel") ts.close(ch) ts.verify(ch) }() } closeVerify.Wait() if ts.relayCh != nil { ts.close(ts.relayCh) ts.verify(ts.relayCh) } // Verify that there's no goroutine leaks after all tests are complete. ts.verifyNoGoroutinesLeaked() } // AssertRelayStats checks that the relayed call graph matches expectations. If // there's no relay, AssertRelayStats is a no-op. func (ts *TestServer) AssertRelayStats(expected *relaytest.MockStats) { if !ts.HasRelay() { return } if ts.relayStats == nil { ts.TB.Error("Cannot verify relay stats, passed in RelayStats does not implement relayStatter") return } ts.relayStats.AssertEqual(ts, expected) } // NewClient returns a client that with log verification. // TODO: Verify message exchanges and leaks for client channels as well. func (ts *TestServer) NewClient(opts *ChannelOpts) *tchannel.Channel { return ts.addChannel(newClient, opts.Copy()) } // NewServer returns a server with log and channel state verification. // // Note: The same default service name is used if one isn't specified. func (ts *TestServer) NewServer(opts *ChannelOpts) *tchannel.Channel { ch := ts.addChannel(newServer, opts.Copy()) if ts.relayHost != nil { ts.relayHost.Add(ch.ServiceName(), ch.PeerInfo().HostPort) } return ch } // addRelay adds a relay in front of the test server, altering public methods as // necessary to route traffic through the relay. 
func (ts *TestServer) addRelay(parentOpts *ChannelOpts) { opts := parentOpts.Copy() relayHost := opts.ChannelOptions.RelayHost if relayHost == nil { ts.relayHost = relaytest.NewStubRelayHost() relayHost = ts.relayHost } else if relayHost, ok := relayHost.(*relaytest.StubRelayHost); ok { ts.relayHost = relayHost } opts.ServiceName = "relay" opts.ChannelOptions.RelayHost = relayHost ts.relayCh = ts.addChannel(newServer, opts) if ts.relayHost != nil && ts.HasServer() { ts.relayHost.Add(ts.Server().ServiceName(), ts.Server().PeerInfo().HostPort) } if statter, ok := relayHost.(relayStatter); ok { ts.relayStats = statter.Stats() } } func (ts *TestServer) addChannel(createChannel func(t testing.TB, opts *ChannelOpts) *tchannel.Channel, opts *ChannelOpts) *tchannel.Channel { ch := createChannel(ts, opts) ts.postFns = append(ts.postFns, opts.postFns...) ts.channels = append(ts.channels, ch) ts.channelStates[ch] = comparableState(ch, ts.introspectOpts) return ch } // close closes all channels in most-recently-created order. // it waits for the channels to close. func (ts *TestServer) close(ch *tchannel.Channel) { ch.Close() timeout := Timeout(time.Second) select { case <-time.After(timeout): ts.Errorf("Channel %p did not close after %v, last state: %v", ch, timeout, ch.State()) // The introspected state might help debug why the channel isn't closing. ts.Logf("Introspected state:\n%s", IntrospectJSON(ch, &tchannel.IntrospectionOptions{ IncludeExchanges: true, IncludeTombstones: true, })) case <-ch.ClosedChan(): } } func (ts *TestServer) verify(ch *tchannel.Channel) { if ts.Failed() { return } // Tests may end with running background goroutines that are cleaning up, so give // them some time to finish before running verifications. var errs error WaitFor(time.Second, func() bool { errs = multierr.Combine( ts.verifyExchangesCleared(ch), ts.verifyRelaysEmpty(ch), ) return errs == nil }) if errs == nil { return } // If verification fails, get the marshalled state. 
assert.NoError(ts, errs, "Verification failed. Channel state:\n%v", IntrospectJSON(ch, nil /* opts */)) } // AddPostFn registers a function that will be executed after channels are closed. func (ts *TestServer) AddPostFn(fn func()) { ts.postFns = append(ts.postFns, fn) } func (ts *TestServer) post() { if !ts.Failed() { for _, ch := range ts.channels { ts.verifyNoStateLeak(ch) } } for _, fn := range ts.postFns { fn() } } func (ts *TestServer) verifyNoStateLeak(ch *tchannel.Channel) { initial := ts.channelStates[ch] final := comparableState(ch, ts.introspectOpts) assert.Equal(ts.TB, initial, final, "Runtime state has leaks") } func (ts *TestServer) verifyExchangesCleared(ch *tchannel.Channel) error { // Ensure that all the message exchanges are empty. serverState := ch.IntrospectState(ts.introspectOpts) if exchangesLeft := describeLeakedExchanges(serverState); exchangesLeft != "" { return fmt.Errorf("found uncleared message exchanges on %q:\n%v", ch.ServiceName(), exchangesLeft) } return nil } func (ts *TestServer) verifyRelaysEmpty(ch *tchannel.Channel) error { var errs error state := ch.IntrospectState(ts.introspectOpts) for _, peerState := range state.RootPeers { var connStates []tchannel.ConnectionRuntimeState connStates = append(connStates, peerState.InboundConnections...) connStates = append(connStates, peerState.OutboundConnections...) for _, connState := range connStates { n := connState.Relayer.Count if n != 0 { errs = multierr.Append(errs, fmt.Errorf("found %v left-over items in relayer for %v", n, connState.LocalHostPort)) } } } return errs } func (ts *TestServer) verifyNoGoroutinesLeaked() { if _leakedGoroutine.Load() { ts.Log("Skipping check for leaked goroutines because of a previous leak.") return } err := goroutines.IdentifyLeaks(ts.verifyOpts) if err == nil { // No leaks, nothing to do. 
return } if isFirstLeak := _leakedGoroutine.CAS(false, true); !isFirstLeak { ts.Log("Skipping check for leaked goroutines because of a previous leak.") return } if ts.Failed() { // If we've already failed this test, don't pollute the test output with // more failures. return } ts.Error(err.Error()) } func comparableState(ch *tchannel.Channel, opts *tchannel.IntrospectionOptions) *tchannel.RuntimeState { s := ch.IntrospectState(opts) s.SubChannels = nil s.Peers = nil // Tests start with ChannelClient or ChannelListening, but end with ChannelClosed. s.ChannelState = "" return s } func describeLeakedExchanges(rs *tchannel.RuntimeState) string { var connections []*tchannel.ConnectionRuntimeState for _, peer := range rs.RootPeers { for _, conn := range peer.InboundConnections { connections = append(connections, &conn) } for _, conn := range peer.OutboundConnections { connections = append(connections, &conn) } } return describeLeakedExchangesConns(connections) } func describeLeakedExchangesConns(connections []*tchannel.ConnectionRuntimeState) string { var exchanges []string for _, c := range connections { if exch := describeLeakedExchangesSingleConn(c); exch != "" { exchanges = append(exchanges, exch) } } return strings.Join(exchanges, "\n") } func describeLeakedExchangesSingleConn(cs *tchannel.ConnectionRuntimeState) string { var exchanges []string checkExchange := func(e tchannel.ExchangeSetRuntimeState) { if e.Count > 0 { exchanges = append(exchanges, fmt.Sprintf(" %v leftover %v exchanges", e.Name, e.Count)) for _, v := range e.Exchanges { exchanges = append(exchanges, fmt.Sprintf(" exchanges: %+v", v)) } } } checkExchange(cs.InboundExchange) checkExchange(cs.OutboundExchange) if len(exchanges) == 0 { return "" } return fmt.Sprintf("Connection %d has leftover exchanges:\n\t%v", cs.ID, strings.Join(exchanges, "\n\t")) } func withServer(t testing.TB, chanOpts *ChannelOpts, f func(testing.TB, *TestServer)) { ts := NewTestServer(t, chanOpts) // Note: We use defer, as we 
// ErrUser is the error the ChunkReader reader returns when a nil chunk is
// received from the control channel.
var ErrUser = errors.New("error set by user")

// ChunkReader returns a control channel together with a reader that serves
// the chunks sent on that channel. Sending a byte slice makes its contents
// available to Read, sending nil makes Read return ErrUser, and closing the
// channel makes Read return io.EOF.
func ChunkReader() (chan<- []byte, io.Reader) {
	control := make(chan []byte, 100)
	return control, &errorReader{c: control}
}

// errorReader is the io.Reader behind ChunkReader: c delivers chunks and
// remaining holds the part of the current chunk not yet handed out.
type errorReader struct {
	c         chan []byte
	remaining []byte
}

// Read copies pending bytes of the current chunk into bs. When no bytes are
// pending it receives the next chunk first: a closed channel yields io.EOF,
// a nil chunk yields ErrUser, and an empty chunk yields a (0, nil) read.
func (r *errorReader) Read(bs []byte) (int, error) {
	if len(r.remaining) == 0 {
		chunk, open := <-r.c
		r.remaining = chunk
		switch {
		case !open:
			return 0, io.EOF
		case chunk == nil:
			return 0, ErrUser
		case len(chunk) == 0:
			return 0, nil
		}
	}

	copied := copy(bs, r.remaining)
	r.remaining = r.remaining[copied:]
	return copied, nil
}
package testreader import ( "io" "io/ioutil" "testing" "github.com/stretchr/testify/assert" ) func TestChunkReader0ByteRead(t *testing.T) { writer, reader := ChunkReader() writer <- []byte{} writer <- []byte{'a'} close(writer) buf := make([]byte, 1) n, err := reader.Read(buf) assert.NoError(t, err, "Read should not fail") assert.Equal(t, 0, n, "Read should not read any bytes") n, err = reader.Read(buf) assert.NoError(t, err, "Read should not fail") assert.Equal(t, 1, n, "Read should read one byte") assert.EqualValues(t, 'a', buf[0], "Read did not read correct byte") n, err = reader.Read(buf) assert.Equal(t, io.EOF, err, "Read should EOF") assert.Equal(t, 0, n, "Read should not read any bytes") } func TestChunkReader(t *testing.T) { writer, reader := ChunkReader() writer <- []byte{1, 2} writer <- []byte{3} writer <- nil writer <- []byte{4} writer <- []byte{} writer <- []byte{5} writer <- []byte{} writer <- []byte{6} writer <- []byte{} close(writer) buf, err := ioutil.ReadAll(reader) assert.Equal(t, ErrUser, err, "Expected error after initial bytes") assert.Equal(t, []byte{1, 2, 3}, buf, "Unexpected bytes") buf, err = ioutil.ReadAll(reader) assert.NoError(t, err, "Reader shouldn't fail on second set of bytes") assert.Equal(t, []byte{4, 5, 6}, buf, "Unexpected bytes") } ================================================ FILE: testutils/testreader/loop.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// loopReader is an io.Reader that cycles through bs without end; pos is the
// index in bs of the next byte to hand out.
type loopReader struct {
	bs  []byte
	pos int
}

// Read fills p completely with bytes taken from bs, wrapping around to the
// start of bs whenever its end is reached, and returns len(p). The position
// carries over between calls, so consecutive reads continue where the
// previous one stopped. If bs is empty there is nothing to loop over and
// Read returns io.EOF instead of indexing out of range.
//
// The receiver is a pointer so that the advance of pos is kept; with a value
// receiver the update was applied to a copy and every Read restarted at 0.
func (r *loopReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if len(r.bs) == 0 {
		return 0, io.EOF
	}
	for i := range p {
		p[i] = r.bs[r.pos]
		if r.pos++; r.pos == len(r.bs) {
			r.pos = 0
		}
	}
	return len(p), nil
}

// Looper returns a reader that will return the bytes in bs as if it was a
// circular buffer.
func Looper(bs []byte) io.Reader {
	return &loopReader{bs, 0}
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testreader import ( "testing" "github.com/stretchr/testify/assert" ) func TestLooper(t *testing.T) { tests := []struct { bs []byte expected []byte }{ {[]byte{0x1}, []byte{0x1, 0x1, 0x1}}, {[]byte{0x1, 0x2}, []byte{0x1, 0x2, 0x1}}, {[]byte{0x1, 0x2}, []byte{0x1, 0x2, 0x1, 0x2, 0x1, 0x2}}, } for _, tt := range tests { r := Looper(tt.bs) got := make([]byte, len(tt.expected)) n, err := r.Read(got) assert.NoError(t, err, "Read failed") assert.Equal(t, len(got), n) assert.Equal(t, tt.expected, got, "Got unexpected bytes") } } ================================================ FILE: testutils/testtracing/propagation.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
const (
	// BaggageKey is the name of the baggage item used for testing baggage
	// propagation.
	BaggageKey = "luggage"
	// BaggageValue is the value stored under BaggageKey when testing baggage
	// propagation.
	BaggageValue = "suitcase"
)
type TracingRequest struct { // ForwardCount tells the server how many times to forward this request to itself recursively ForwardCount int } // TracingResponse captures the trace info observed in the server and its downstream calls type TracingResponse struct { TraceID uint64 SpanID uint64 ParentID uint64 TracingEnabled bool Child *TracingResponse Luggage string } // ObserveSpan extracts an OpenTracing span from the context and populates the response. func (r *TracingResponse) ObserveSpan(ctx context.Context) *TracingResponse { if span := opentracing.SpanFromContext(ctx); span != nil { if mockSpan, ok := span.(*mocktracer.MockSpan); ok { sc := mockSpan.Context().(mocktracer.MockSpanContext) r.TraceID = uint64(sc.TraceID) r.SpanID = uint64(sc.SpanID) r.ParentID = uint64(mockSpan.ParentID) r.TracingEnabled = sc.Sampled } else if span := tchannel.CurrentSpan(ctx); span != nil { r.TraceID = span.TraceID() r.SpanID = span.SpanID() r.ParentID = span.ParentID() r.TracingEnabled = span.Flags()&1 == 1 } r.Luggage = span.BaggageItem(BaggageKey) } return r } // TraceHandler is a base class for testing tracing propagation type TraceHandler struct { Ch *tchannel.Channel } // HandleCall is used by handlers from different encodings as the main business logic. // It respects the ForwardCount input parameter to make downstream calls, and returns // a result containing the observed tracing span and the downstream results. func (h *TraceHandler) HandleCall( ctx context.Context, req *TracingRequest, downstream TracingCall, ) (*TracingResponse, error) { var childResp *TracingResponse if req.ForwardCount > 0 { downstreamReq := &TracingRequest{ForwardCount: req.ForwardCount - 1} if resp, err := downstream(ctx, downstreamReq); err == nil { childResp = resp } else { return nil, err } } resp := &TracingResponse{Child: childResp} resp.ObserveSpan(ctx) return resp, nil } // TracerType is a convenient enum to indicate which type of tracer is being used in the test. 
// It is a string because it's printed as part of the test description in the logs. type TracerType string const ( // Noop is for the default no-op tracer from OpenTracing Noop TracerType = "NOOP" // Mock tracer, baggage-capable, non-Zipkin trace IDs Mock TracerType = "MOCK" // Jaeger is Uber's tracer, baggage-capable, Zipkin-style trace IDs Jaeger TracerType = "JAEGER" ) // TracingCall is used in a few other structs here type TracingCall func(ctx context.Context, req *TracingRequest) (*TracingResponse, error) // EncodingInfo describes the encoding used with tracing propagation test type EncodingInfo struct { Format tchannel.Format HeadersSupported bool } // PropagationTestSuite is a collection of test cases for a certain encoding type PropagationTestSuite struct { Encoding EncodingInfo Register func(t *testing.T, ch *tchannel.Channel) TracingCall TestCases map[TracerType][]PropagationTestCase } // PropagationTestCase describes a single propagation test case and expected results type PropagationTestCase struct { ForwardCount int TracingDisabled bool ExpectedBaggage string ExpectedSpanCount int } type tracerChoice struct { tracerType TracerType tracer opentracing.Tracer spansRecorded func() int resetSpans func() isFake bool zipkinCompatible bool } // Run executes the test cases in the test suite against 3 different tracer implementations func (s *PropagationTestSuite) Run(t *testing.T) { tests := []struct { name string run func(t *testing.T) }{ {"Noop_Tracer", s.runWithNoopTracer}, {"Mock_Tracer", s.runWithMockTracer}, {"Jaeger_Tracer", s.runWithJaegerTracer}, } for _, test := range tests { t.Logf("Running with %s", test.name) test.run(t) } } func (s *PropagationTestSuite) runWithNoopTracer(t *testing.T) { s.runWithTracer(t, tracerChoice{ tracer: nil, // will cause opentracing.GlobalTracer() to be used tracerType: Noop, spansRecorded: func() int { return 0 }, resetSpans: func() {}, isFake: true, }) } func (s *PropagationTestSuite) runWithMockTracer(t *testing.T) { 
mockTracer := mocktracer.New() s.runWithTracer(t, tracerChoice{ tracerType: Mock, tracer: mockTracer, spansRecorded: func() int { return len(MockTracerSampledSpans(mockTracer)) }, resetSpans: func() { mockTracer.Reset() }, }) } func (s *PropagationTestSuite) runWithJaegerTracer(t *testing.T) { jaegerReporter := jaeger.NewInMemoryReporter() jaegerTracer, jaegerCloser := jaeger.NewTracer(testutils.DefaultServerName, jaeger.NewConstSampler(true), jaegerReporter) // To enable logging, use composite reporter: // jaeger.NewCompositeReporter(jaegerReporter, jaeger.NewLoggingReporter(jaeger.StdLogger))) defer jaegerCloser.Close() s.runWithTracer(t, tracerChoice{ tracerType: Jaeger, tracer: jaegerTracer, spansRecorded: func() int { return len(jaegerReporter.GetSpans()) }, resetSpans: func() { jaegerReporter.Reset() }, zipkinCompatible: true, }) } func (s *PropagationTestSuite) runWithTracer(t *testing.T, tracer tracerChoice) { testCases, ok := s.TestCases[tracer.tracerType] if !ok { t.Logf("No test cases for encoding=%s and tracer=%s", s.Encoding.Format, tracer.tracerType) return } opts := &testutils.ChannelOpts{ ChannelOptions: tchannel.ChannelOptions{Tracer: tracer.tracer}, DisableRelay: true, } ch := testutils.NewServer(t, opts) defer ch.Close() ch.Peers().Add(ch.PeerInfo().HostPort) call := s.Register(t, ch) for _, tt := range testCases { s.runTestCase(t, tracer, ch, tt, call) } } func (s *PropagationTestSuite) runTestCase( t *testing.T, tracer tracerChoice, ch *tchannel.Channel, test PropagationTestCase, call TracingCall, ) { descr := fmt.Sprintf("test %+v with tracer %+v", test, tracer) ch.Logger().Debugf("Starting tracing test %s", descr) tracer.resetSpans() span := ch.Tracer().StartSpan("client") span.SetBaggageItem(BaggageKey, BaggageValue) ctx := opentracing.ContextWithSpan(context.Background(), span) ctxBuilder := tchannel.NewContextBuilder(5 * time.Second).SetParentContext(ctx) if test.TracingDisabled { ctxBuilder.DisableTracing() } ctx, cancel := 
ctxBuilder.Build() defer cancel() req := &TracingRequest{ForwardCount: test.ForwardCount} ch.Logger().Infof("Sending tracing request %+v", req) response, err := call(ctx, req) require.NoError(t, err) ch.Logger().Infof("Received tracing response %+v", response) // Spans are finished in inbound.doneSending() or outbound.doneReading(), // which are called on different go-routines and may execute *after* the // response has been received by the client. Give them a chance to run. for i := 0; i < 1000; i++ { if spanCount := tracer.spansRecorded(); spanCount == test.ExpectedSpanCount { break } time.Sleep(testutils.Timeout(time.Millisecond)) } spanCount := tracer.spansRecorded() ch.Logger().Debugf("end span count: %d", spanCount) // finish span after taking count of recorded spans, as we're only interested // in the count of spans created by RPC calls. span.Finish() root := new(TracingResponse).ObserveSpan(ctx) if !tracer.isFake { assert.Equal(t, uint64(0), root.ParentID) assert.NotEqual(t, uint64(0), root.TraceID) } assert.Equal(t, test.ExpectedSpanCount, spanCount, "Wrong span count; %s", descr) for r, cnt := response, 0; r != nil || cnt <= test.ForwardCount; r, cnt = r.Child, cnt+1 { require.NotNil(t, r, "Expecting response for forward=%d; %s", cnt, descr) if !tracer.isFake { if tracer.zipkinCompatible || s.Encoding.HeadersSupported { assert.Equal(t, root.TraceID, r.TraceID, "traceID should be the same; %s", descr) } assert.Equal(t, test.ExpectedBaggage, r.Luggage, "baggage should propagate; %s", descr) } } ch.Logger().Debugf("Finished tracing test %s", descr) } // MockTracerSampledSpans is a helper function that returns only sampled spans from MockTracer func MockTracerSampledSpans(tracer *mocktracer.MockTracer) []*mocktracer.MockSpan { var spans []*mocktracer.MockSpan for _, span := range tracer.FinishedSpans() { if span.Context().(mocktracer.MockSpanContext).Sampled { spans = append(spans, span) } } return spans } ================================================ 
FILE: testutils/testtracing/propagation_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package testtracing import ( json_encoding "encoding/json" "testing" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "golang.org/x/net/context" ) func requestFromRaw(args *raw.Args) *TracingRequest { r := new(TracingRequest) r.ForwardCount = int(args.Arg3[0]) return r } func requestToRaw(r *TracingRequest) []byte { return []byte{byte(r.ForwardCount)} } func responseFromRaw(t *testing.T, arg3 []byte) (*TracingResponse, error) { var r TracingResponse err := json_encoding.Unmarshal(arg3, &r) if err != nil { return nil, err } return &r, nil } func responseToRaw(t *testing.T, r *TracingResponse) (*raw.Res, error) { jsonBytes, err := json_encoding.Marshal(r) if err != nil { return nil, err } return &raw.Res{Arg3: jsonBytes}, nil } // RawHandler tests tracing over Raw encoding type RawHandler struct { TraceHandler t *testing.T } func (h *RawHandler) Handle(ctx context.Context, args *raw.Args) (*raw.Res, error) { req := requestFromRaw(args) res, err := h.HandleCall(ctx, req, func(ctx context.Context, req *TracingRequest) (*TracingResponse, error) { _, arg3, _, err := raw.Call(ctx, h.Ch, h.Ch.PeerInfo().HostPort, h.Ch.PeerInfo().ServiceName, "rawcall", nil, requestToRaw(req)) if err != nil { return nil, err } return responseFromRaw(h.t, arg3) }) if err != nil { return nil, err } return responseToRaw(h.t, res) } func (h *RawHandler) OnError(ctx context.Context, err error) { h.t.Errorf("onError %v", err) } func (h *RawHandler) firstCall(ctx context.Context, req *TracingRequest) (*TracingResponse, error) { _, arg3, _, err := raw.Call(ctx, h.Ch, h.Ch.PeerInfo().HostPort, h.Ch.PeerInfo().ServiceName, "rawcall", nil, requestToRaw(req)) if err != nil { return nil, err } return responseFromRaw(h.t, arg3) } func TestRawTracingPropagation(t *testing.T) { suite := &PropagationTestSuite{ Encoding: EncodingInfo{Format: tchannel.Raw, HeadersSupported: false}, Register: func(t *testing.T, ch *tchannel.Channel) TracingCall { handler := &RawHandler{ TraceHandler: 
TraceHandler{Ch: ch}, t: t, } ch.Register(raw.Wrap(handler), "rawcall") return handler.firstCall }, // Since Raw encoding does not support headers, there is no baggage propagation TestCases: map[TracerType][]PropagationTestCase{ Noop: { {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: "", ExpectedSpanCount: 0}, {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: "", ExpectedSpanCount: 0}, }, Mock: { // Since Raw encoding does not propagate generic traces, the tracingDisable // only affects the first outbound span (it's not sampled), but the other // two outbound spans are still sampled and recorded. {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: "", ExpectedSpanCount: 2}, // Since Raw encoding does not propagate generic traces, we record 3 spans // for outbound calls, but none for inbound calls. {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: "", ExpectedSpanCount: 3}, }, Jaeger: { // Since Jaeger is Zipkin-compatible, it is able to keep track of tracingDisabled {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: "", ExpectedSpanCount: 0}, // Since Jaeger is Zipkin-compatible, it is able to decode the trace // even from the Raw encoding. {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: "", ExpectedSpanCount: 6}, }, }, } suite.Run(t) } ================================================ FILE: testutils/testwriter/limited.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// ErrOutOfSpace is returned by the Limited writer when it is out of bytes.
var ErrOutOfSpace = errors.New("out of space")

// writerFunc adapts a plain function to the io.Writer interface.
type writerFunc func([]byte) (int, error)

// Write calls f with p and returns its result.
func (f writerFunc) Write(p []byte) (n int, err error) {
	return f(p)
}

// Limited returns an io.Writer that will only accept n bytes.
//
// A write that fits in the remaining budget is accepted in full. A write
// that does not fit reports the number of bytes that still fit together with
// ErrOutOfSpace and uses up the budget, so every later non-empty write fails
// with ErrOutOfSpace and zero bytes written. A negative n is treated as
// zero, so Write never reports a negative byte count.
func Limited(n int) io.Writer {
	if n < 0 {
		n = 0
	}
	return writerFunc(func(p []byte) (int, error) {
		if n < len(p) {
			retN := n
			n = 0
			return retN, ErrOutOfSpace
		}
		n -= len(p)
		return len(p), nil
	})
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package testwriter import ( "testing" "github.com/stretchr/testify/assert" ) func TestLimitedWriter(t *testing.T) { tests := []struct { limit int writeBytes []byte wantErr error wantBytes int }{ { limit: 1, writeBytes: []byte{1}, wantBytes: 1, }, { limit: 1, writeBytes: []byte{1, 2}, wantErr: ErrOutOfSpace, wantBytes: 1, }, { limit: 0, writeBytes: nil, wantBytes: 0, }, { limit: 5, writeBytes: []byte{1, 2, 3, 4, 5, 6}, wantErr: ErrOutOfSpace, wantBytes: 5, }, } for _, tt := range tests { writer := Limited(tt.limit) n, err := writer.Write(tt.writeBytes) if tt.wantErr != nil { assert.Equal(t, tt.wantErr, err, "Write %v to Limited(%v) should fail", tt.writeBytes, tt.limit) } else { assert.NoError(t, err, "Write %v to Limited(%v) should not fail", tt.writeBytes, tt.limit) } assert.Equal(t, tt.wantBytes, n, "Unexpected number of bytes written to Limited(%v)", tt.limit) n, err = writer.Write([]byte{2}) assert.Equal(t, ErrOutOfSpace, err, "Write should be out of space") assert.Equal(t, 0, n, "Write should not write any bytes when it is out of space") } } func TestLimitedWriter2(t *testing.T) { writer := Limited(1) n, err := writer.Write([]byte{1, 2}) assert.Equal(t, ErrOutOfSpace, err, "Write should fail") assert.Equal(t, 1, n, "Write should only write one byte") n, err = writer.Write([]byte{2}) assert.Equal(t, ErrOutOfSpace, err, "Write should be out of space") assert.Equal(t, 0, n, "Write should not write any bytes when it is out of space") } ================================================ FILE: testutils/thriftarg2test/arg2_kv.go ================================================ package thriftarg2test import ( "fmt" "testing" "github.com/stretchr/testify/require" "github.com/uber/tchannel-go/typed" ) // BuildKVBuffer builds an thrift Arg2 KV buffer. 
func BuildKVBuffer(kv map[string]string) []byte { // Scan once to know size of buffer var bufSize int for k, v := range kv { // k~2 v~2 bufSize += 2 + len(k) + 2 + len(v) } bufSize += 2 // nh:2 buf := make([]byte, bufSize) wb := typed.NewWriteBuffer(buf) wb.WriteUint16(uint16(len(kv))) for k, v := range kv { wb.WriteLen16String(k) wb.WriteLen16String(v) } return buf[:wb.BytesWritten()] } // ReadKVBuffer converts an arg2 buffer to a string map func ReadKVBuffer(b []byte) (map[string]string, error) { rbuf := typed.NewReadBuffer(b) nh := rbuf.ReadUint16() retMap := make(map[string]string, nh) for i := uint16(0); i < nh; i++ { key := rbuf.ReadLen16String() val := rbuf.ReadLen16String() retMap[key] = val } if rbuf.BytesRemaining() > 0 { return nil, fmt.Errorf("kv buffer wasn't fully consumed (%d bytes remaining)", rbuf.BytesRemaining()) } return retMap, nil } // MustReadKVBuffer calls require.NoError on the error returned by ReadKVBuffer func MustReadKVBuffer(tb testing.TB, b []byte) map[string]string { m, err := ReadKVBuffer(b) require.NoError(tb, err) return m } ================================================ FILE: testutils/thriftarg2test/arg2_kv_test.go ================================================ package thriftarg2test import ( "testing" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go/typed" ) func TestBuildKVBuffer(t *testing.T) { kv := map[string]string{ "key": "valval", "key2": "val", } buf := BuildKVBuffer(kv) rb := typed.NewReadBuffer(buf) assert.EqualValues(t, len(kv), rb.ReadUint16()) gotKV := make(map[string]string) for i := 0; i < len(kv); i++ { k := rb.ReadLen16String() v := rb.ReadLen16String() gotKV[k] = v } assert.Equal(t, kv, gotKV) } func TestReadKVBuffer(t *testing.T) { kvMap := map[string]string{ "key": "valval", "key2": "val", } var buffer [128]byte wbuf := typed.NewWriteBuffer(buffer[:]) wbuf.WriteUint16(uint16(len(kvMap))) // nh // the order doesn't matter here since we're comparing maps for k, v := range kvMap { 
// FakeTicker is a ticker for unit tests that can be controlled
// deterministically: it only ticks when Tick or TryTick is called.
type FakeTicker struct {
	c chan time.Time
}

// NewFakeTicker returns a new instance of FakeTicker whose tick channel can
// hold one pending tick.
func NewFakeTicker() *FakeTicker {
	return &FakeTicker{
		c: make(chan time.Time, 1),
	}
}

// Tick sends a tick carrying the current time to the receiver. It blocks
// while a previous tick is still pending.
func (ft *FakeTicker) Tick() {
	ft.c <- time.Now()
}

// TryTick attempts to send a tick without blocking. It reports whether the
// tick was queued; false means a previous tick is still pending.
//
// Like Tick, the delivered value is the current time, so consumers that read
// the tick value never observe the zero time.
func (ft *FakeTicker) TryTick() bool {
	select {
	case ft.c <- time.Now():
		return true
	default:
		return false
	}
}

// New can be used in tests as a factory method for tickers, by passing it to
// ChannelOptions.TimeTicker.
//
// The requested duration d is ignored. A real ticker with a one hour period
// is created only to obtain a *time.Ticker value, and its channel is replaced
// with the fake's channel so that ticks come from Tick/TryTick.
func (ft *FakeTicker) New(d time.Duration) *time.Ticker {
	t := time.NewTicker(time.Hour)
	t.C = ft.c
	return t
}
package testutils import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestFakeTicker(t *testing.T) { ft := NewFakeTicker() ticker := ft.New(time.Second) select { case <-ticker.C: t.Fatalf("Fake ticker ticked by itself") default: } for i := 0; i < 10; i++ { if i%2 == 0 { assert.True(t, ft.TryTick(), "TryTick should succeed with no other pending ticks") assert.False(t, ft.TryTick(), "TryTick should fail with pending ticks") } else { ft.Tick() } select { case <-ticker.C: default: t.Fatalf("Fake ticker tick did not unblock ticker") } } ticker.Stop() select { case <-ticker.C: t.Fatalf("Stopped ticker ticked") default: } } ================================================ FILE: testutils/timeout.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
// timeoutScaleFactor multiplies every duration passed through Timeout. It is
// 1.0 unless TEST_TIMEOUT_SCALE is set in the environment.
var timeoutScaleFactor = 1.0

// init reads TEST_TIMEOUT_SCALE. When it is set it must parse as a float;
// otherwise the process panics. The chosen factor is reported on stderr.
func init() {
	if v := os.Getenv("TEST_TIMEOUT_SCALE"); v != "" {
		fv, err := strconv.ParseFloat(v, 64)
		if err != nil {
			panic(err)
		}
		timeoutScaleFactor = fv
		fmt.Fprintln(os.Stderr, "Scaling timeouts by factor", timeoutScaleFactor)
	}
}

// Timeout returns the timeout multiplied by any set multiplier.
func Timeout(timeout time.Duration) time.Duration {
	return time.Duration(timeoutScaleFactor * float64(timeout))
}

// getCallerName returns the name of the test that (directly or indirectly)
// called the function calling getCallerName.
//
// It walks the call stack from the outermost frame inwards and returns the
// function invoked by the first testing.* frame it meets. If no testing.*
// frame has a callee on the stack, it returns "unknown".
func getCallerName() string {
	// Capture the whole stack: grow the buffer until Callers no longer fills
	// it, so deep call chains are not truncated before the testing frame.
	pcs := make([]uintptr, 16)
	n := runtime.Callers(2, pcs)
	for n == len(pcs) {
		pcs = make([]uintptr, 2*len(pcs))
		n = runtime.Callers(2, pcs)
	}
	if n == 0 {
		return "unknown"
	}

	// CallersFrames is the documented way to turn return PCs into function
	// names; it also accounts for inlined calls.
	var names []string
	frames := runtime.CallersFrames(pcs[:n])
	for {
		frame, more := frames.Next()
		names = append(names, frame.Function)
		if !more {
			break
		}
	}

	// names[0] is the innermost frame. Stop at index 1 so there is always a
	// callee (names[i-1]) to report.
	for i := len(names) - 1; i > 0; i-- {
		if strings.HasPrefix(names[i], "testing.") {
			return names[i-1]
		}
	}
	return "unknown"
}

// SetTimeout is used to fail tests after a timeout. It returns a function that should be
// run once the test is complete. The standard way is to use defer, e.g.
// defer SetTimeout(t, time.Second)()
//
// The timeout is scaled through Timeout first. If it expires before the
// returned function is called, the expiry is logged through t and the process
// panics.
func SetTimeout(t *testing.T, timeout time.Duration) func() {
	timeout = Timeout(timeout)

	caller := getCallerName()

	timer := time.AfterFunc(timeout, func() {
		t.Logf("Test %s timed out after %v", caller, timeout)
		// Unfortunately, tests cannot be failed from new goroutines, so use a panic.
		panic(fmt.Errorf("Test %s timed out after %v", caller, timeout))
	})

	return func() {
		timer.Stop()
	}
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package testutils import ( "sync" "time" ) // WaitFor will retry f till it returns true for a maximum of timeout. // It returns true if f returned true, false if timeout was hit. func WaitFor(timeout time.Duration, f func() bool) bool { timeoutEnd := time.Now().Add(Timeout(timeout)) const maxSleep = time.Millisecond * 50 sleepFor := time.Millisecond for { if f() { return true } if time.Now().After(timeoutEnd) { return false } time.Sleep(sleepFor) if sleepFor < maxSleep { sleepFor *= 2 } } } // WaitWG waits for the given WaitGroup to be complete with a timeout // and returns whether the WaitGroup completed within the timeout. 
// WaitWG waits for the given WaitGroup to be complete with a timeout
// and returns whether the WaitGroup completed within the timeout.
//
// The wait runs in a helper goroutine. That goroutine signals completion by
// closing a channel, which never blocks, so it exits as soon as the
// WaitGroup completes even when WaitWG has already returned false.
func WaitWG(wg *sync.WaitGroup, timeout time.Duration) bool {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	// An explicit timer lets us release it as soon as we return instead of
	// leaving it pending until it fires.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-timer.C:
		return false
	case <-done:
		return true
	}
}
*/ package thrift const ( UNKNOWN_APPLICATION_EXCEPTION = 0 UNKNOWN_METHOD = 1 INVALID_MESSAGE_TYPE_EXCEPTION = 2 WRONG_METHOD_NAME = 3 BAD_SEQUENCE_ID = 4 MISSING_RESULT = 5 INTERNAL_ERROR = 6 PROTOCOL_ERROR = 7 ) // Application level Thrift exception type TApplicationException interface { TException TypeId() int32 Read(iprot TProtocol) (TApplicationException, error) Write(oprot TProtocol) error } type tApplicationException struct { message string type_ int32 } func (e tApplicationException) Error() string { return e.message } func NewTApplicationException(type_ int32, message string) TApplicationException { return &tApplicationException{message, type_} } func (p *tApplicationException) TypeId() int32 { return p.type_ } func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, error) { _, err := iprot.ReadStructBegin() if err != nil { return nil, err } message := "" type_ := int32(UNKNOWN_APPLICATION_EXCEPTION) for { _, ttype, id, err := iprot.ReadFieldBegin() if err != nil { return nil, err } if ttype == STOP { break } switch id { case 1: if ttype == STRING { if message, err = iprot.ReadString(); err != nil { return nil, err } } else { if err = SkipDefaultDepth(iprot, ttype); err != nil { return nil, err } } case 2: if ttype == I32 { if type_, err = iprot.ReadI32(); err != nil { return nil, err } } else { if err = SkipDefaultDepth(iprot, ttype); err != nil { return nil, err } } default: if err = SkipDefaultDepth(iprot, ttype); err != nil { return nil, err } } if err = iprot.ReadFieldEnd(); err != nil { return nil, err } } return NewTApplicationException(type_, message), iprot.ReadStructEnd() } func (p *tApplicationException) Write(oprot TProtocol) (err error) { err = oprot.WriteStructBegin("TApplicationException") if len(p.Error()) > 0 { err = oprot.WriteFieldBegin("message", STRING, 1) if err != nil { return } err = oprot.WriteString(p.Error()) if err != nil { return } err = oprot.WriteFieldEnd() if err != nil { return } } err = 
oprot.WriteFieldBegin("type", I32, 2) if err != nil { return } err = oprot.WriteI32(p.type_) if err != nil { return } err = oprot.WriteFieldEnd() if err != nil { return } err = oprot.WriteFieldStop() if err != nil { return } err = oprot.WriteStructEnd() return } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/application_exception_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "testing" ) func TestTApplicationException(t *testing.T) { exc := NewTApplicationException(UNKNOWN_APPLICATION_EXCEPTION, "") if exc.Error() != "" { t.Fatalf("Expected empty string for exception but found '%s'", exc.Error()) } if exc.TypeId() != UNKNOWN_APPLICATION_EXCEPTION { t.Fatalf("Expected type UNKNOWN for exception but found '%s'", exc.TypeId()) } exc = NewTApplicationException(WRONG_METHOD_NAME, "junk_method") if exc.Error() != "junk_method" { t.Fatalf("Expected 'junk_method' for exception but found '%s'", exc.Error()) } if exc.TypeId() != WRONG_METHOD_NAME { t.Fatalf("Expected type WRONG_METHOD_NAME for exception but found '%s'", exc.TypeId()) } } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/binary_protocol.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "bytes" "encoding/binary" "errors" "fmt" "io" "math" ) type TBinaryProtocol struct { trans TRichTransport origTransport TTransport reader io.Reader writer io.Writer strictRead bool strictWrite bool buffer [64]byte } type TBinaryProtocolFactory struct { strictRead bool strictWrite bool } func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol { return NewTBinaryProtocol(t, false, true) } func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol { p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite} if et, ok := t.(TRichTransport); ok { p.trans = et } else { p.trans = NewTRichTransport(t) } p.reader = p.trans p.writer = p.trans return p } func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory { return NewTBinaryProtocolFactory(false, true) } func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory { return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite} } func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol { return NewTBinaryProtocol(t, p.strictRead, p.strictWrite) } /** * Writing Methods */ func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { if p.strictWrite { version := uint32(VERSION_1) | uint32(typeId) e := p.WriteI32(int32(version)) if e != nil { return e } e = p.WriteString(name) if e != nil { return e } e = p.WriteI32(seqId) return e } else { e := p.WriteString(name) if e != nil { return e } e = p.WriteByte(int8(typeId)) if e != nil { return e } e = p.WriteI32(seqId) return e } return nil } func (p *TBinaryProtocol) WriteMessageEnd() error { return nil } func (p *TBinaryProtocol) WriteStructBegin(name string) error { return nil } func (p *TBinaryProtocol) WriteStructEnd() error { return nil } func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { e := p.WriteByte(int8(typeId)) if e != nil { return e } e = 
p.WriteI16(id) return e } func (p *TBinaryProtocol) WriteFieldEnd() error { return nil } func (p *TBinaryProtocol) WriteFieldStop() error { e := p.WriteByte(STOP) return e } func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { e := p.WriteByte(int8(keyType)) if e != nil { return e } e = p.WriteByte(int8(valueType)) if e != nil { return e } e = p.WriteI32(int32(size)) return e } func (p *TBinaryProtocol) WriteMapEnd() error { return nil } func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error { e := p.WriteByte(int8(elemType)) if e != nil { return e } e = p.WriteI32(int32(size)) return e } func (p *TBinaryProtocol) WriteListEnd() error { return nil } func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error { e := p.WriteByte(int8(elemType)) if e != nil { return e } e = p.WriteI32(int32(size)) return e } func (p *TBinaryProtocol) WriteSetEnd() error { return nil } func (p *TBinaryProtocol) WriteBool(value bool) error { if value { return p.WriteByte(1) } return p.WriteByte(0) } func (p *TBinaryProtocol) WriteByte(value int8) error { e := p.trans.WriteByte(byte(value)) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI16(value int16) error { v := p.buffer[0:2] binary.BigEndian.PutUint16(v, uint16(value)) _, e := p.writer.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI32(value int32) error { v := p.buffer[0:4] binary.BigEndian.PutUint32(v, uint32(value)) _, e := p.writer.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI64(value int64) error { v := p.buffer[0:8] binary.BigEndian.PutUint64(v, uint64(value)) _, err := p.writer.Write(v) return NewTProtocolException(err) } func (p *TBinaryProtocol) WriteDouble(value float64) error { return p.WriteI64(int64(math.Float64bits(value))) } func (p *TBinaryProtocol) WriteString(value string) error { e := p.WriteI32(int32(len(value))) if e != nil { return e } _, err := p.trans.WriteString(value) 
return NewTProtocolException(err) } func (p *TBinaryProtocol) WriteBinary(value []byte) error { e := p.WriteI32(int32(len(value))) if e != nil { return e } _, err := p.writer.Write(value) return NewTProtocolException(err) } /** * Reading methods */ func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { size, e := p.ReadI32() if e != nil { return "", typeId, 0, NewTProtocolException(e) } if size < 0 { typeId = TMessageType(size & 0x0ff) version := int64(int64(size) & VERSION_MASK) if version != VERSION_1 { return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin")) } name, e = p.ReadString() if e != nil { return name, typeId, seqId, NewTProtocolException(e) } seqId, e = p.ReadI32() if e != nil { return name, typeId, seqId, NewTProtocolException(e) } return name, typeId, seqId, nil } if p.strictRead { return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin")) } name, e2 := p.readStringBody(size) if e2 != nil { return name, typeId, seqId, e2 } b, e3 := p.ReadByte() if e3 != nil { return name, typeId, seqId, e3 } typeId = TMessageType(b) seqId, e4 := p.ReadI32() if e4 != nil { return name, typeId, seqId, e4 } return name, typeId, seqId, nil } func (p *TBinaryProtocol) ReadMessageEnd() error { return nil } func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) { return } func (p *TBinaryProtocol) ReadStructEnd() error { return nil } func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) { t, err := p.ReadByte() typeId = TType(t) if err != nil { return name, typeId, seqId, err } if t != STOP { seqId, err = p.ReadI16() } return name, typeId, seqId, err } func (p *TBinaryProtocol) ReadFieldEnd() error { return nil } var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length")) func (p *TBinaryProtocol) 
ReadMapBegin() (kType, vType TType, size int, err error) { k, e := p.ReadByte() if e != nil { err = NewTProtocolException(e) return } kType = TType(k) v, e := p.ReadByte() if e != nil { err = NewTProtocolException(e) return } vType = TType(v) size32, e := p.ReadI32() if e != nil { err = NewTProtocolException(e) return } if size32 < 0 { err = invalidDataLength return } size = int(size32) return kType, vType, size, nil } func (p *TBinaryProtocol) ReadMapEnd() error { return nil } func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) { b, e := p.ReadByte() if e != nil { err = NewTProtocolException(e) return } elemType = TType(b) size32, e := p.ReadI32() if e != nil { err = NewTProtocolException(e) return } if size32 < 0 { err = invalidDataLength return } size = int(size32) return } func (p *TBinaryProtocol) ReadListEnd() error { return nil } func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { b, e := p.ReadByte() if e != nil { err = NewTProtocolException(e) return } elemType = TType(b) size32, e := p.ReadI32() if e != nil { err = NewTProtocolException(e) return } if size32 < 0 { err = invalidDataLength return } size = int(size32) return elemType, size, nil } func (p *TBinaryProtocol) ReadSetEnd() error { return nil } func (p *TBinaryProtocol) ReadBool() (bool, error) { b, e := p.ReadByte() v := true if b != 1 { v = false } return v, e } func (p *TBinaryProtocol) ReadByte() (int8, error) { v, err := p.trans.ReadByte() return int8(v), err } func (p *TBinaryProtocol) ReadI16() (value int16, err error) { buf := p.buffer[0:2] err = p.readAll(buf) value = int16(binary.BigEndian.Uint16(buf)) return value, err } func (p *TBinaryProtocol) ReadI32() (value int32, err error) { buf := p.buffer[0:4] err = p.readAll(buf) value = int32(binary.BigEndian.Uint32(buf)) return value, err } func (p *TBinaryProtocol) ReadI64() (value int64, err error) { buf := p.buffer[0:8] err = p.readAll(buf) value = 
int64(binary.BigEndian.Uint64(buf)) return value, err } func (p *TBinaryProtocol) ReadDouble() (value float64, err error) { buf := p.buffer[0:8] err = p.readAll(buf) value = math.Float64frombits(binary.BigEndian.Uint64(buf)) return value, err } func (p *TBinaryProtocol) ReadString() (value string, err error) { size, e := p.ReadI32() if e != nil { return "", e } if size < 0 { err = invalidDataLength return } return p.readStringBody(size) } func (p *TBinaryProtocol) ReadBinary() ([]byte, error) { size, e := p.ReadI32() if e != nil { return nil, e } if size < 0 { return nil, invalidDataLength } if uint64(size) > p.trans.RemainingBytes() { return nil, invalidDataLength } isize := int(size) buf := make([]byte, isize) _, err := io.ReadFull(p.trans, buf) return buf, NewTProtocolException(err) } func (p *TBinaryProtocol) Flush() (err error) { return NewTProtocolException(p.trans.Flush()) } func (p *TBinaryProtocol) Skip(fieldType TType) (err error) { return SkipDefaultDepth(p, fieldType) } func (p *TBinaryProtocol) Transport() TTransport { return p.origTransport } func (p *TBinaryProtocol) readAll(buf []byte) error { _, err := io.ReadFull(p.reader, buf) return NewTProtocolException(err) } const readLimit = 32768 func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) { if size < 0 { return "", nil } if uint64(size) > p.trans.RemainingBytes() { return "", invalidDataLength } var ( buf bytes.Buffer e error b []byte ) switch { case int(size) <= len(p.buffer): b = p.buffer[:size] // avoids allocation for small reads case int(size) < readLimit: b = make([]byte, size) default: b = make([]byte, readLimit) } for size > 0 { _, e = io.ReadFull(p.trans, b) buf.Write(b) if e != nil { break } size -= readLimit if size < readLimit && size > 0 { b = b[:size] } } return buf.String(), NewTProtocolException(e) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/binary_protocol_test.go 
================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "testing" ) func TestReadWriteBinaryProtocol(t *testing.T) { ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault()) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/buffered_transport.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "bufio" ) type TBufferedTransportFactory struct { size int } type TBufferedTransport struct { bufio.ReadWriter tp TTransport } func (p *TBufferedTransportFactory) GetTransport(trans TTransport) TTransport { return NewTBufferedTransport(trans, p.size) } func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory { return &TBufferedTransportFactory{size: bufferSize} } func NewTBufferedTransport(trans TTransport, bufferSize int) *TBufferedTransport { return &TBufferedTransport{ ReadWriter: bufio.ReadWriter{ Reader: bufio.NewReaderSize(trans, bufferSize), Writer: bufio.NewWriterSize(trans, bufferSize), }, tp: trans, } } func (p *TBufferedTransport) IsOpen() bool { return p.tp.IsOpen() } func (p *TBufferedTransport) Open() (err error) { return p.tp.Open() } func (p *TBufferedTransport) Close() (err error) { return p.tp.Close() } func (p *TBufferedTransport) Read(b []byte) (int, error) { n, err := p.ReadWriter.Read(b) if err != nil { p.ReadWriter.Reader.Reset(p.tp) } return n, err } func (p *TBufferedTransport) Write(b []byte) (int, error) { n, err := p.ReadWriter.Write(b) if err != nil { p.ReadWriter.Writer.Reset(p.tp) } return n, err } func (p *TBufferedTransport) Flush() error { if err := p.ReadWriter.Flush(); err != nil { p.ReadWriter.Writer.Reset(p.tp) return err } return p.tp.Flush() } func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) { return p.tp.RemainingBytes() } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/buffered_transport_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package thrift

import (
	"testing"
)

// TestBufferedTransport runs the shared TransportTest suite with a single
// buffered transport (over a memory buffer, 10240-byte buffers) used as both
// ends.
func TestBufferedTransport(t *testing.T) {
	trans := NewTBufferedTransport(NewTMemoryBuffer(), 10240)
	TransportTest(t, trans, trans)
}

================================================
FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/compact_protocol.go
================================================
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
*/ package thrift import ( "encoding/binary" "fmt" "io" "math" ) const ( COMPACT_PROTOCOL_ID = 0x082 COMPACT_VERSION = 1 COMPACT_VERSION_MASK = 0x1f COMPACT_TYPE_MASK = 0x0E0 COMPACT_TYPE_BITS = 0x07 COMPACT_TYPE_SHIFT_AMOUNT = 5 ) type tCompactType byte const ( COMPACT_BOOLEAN_TRUE = 0x01 COMPACT_BOOLEAN_FALSE = 0x02 COMPACT_BYTE = 0x03 COMPACT_I16 = 0x04 COMPACT_I32 = 0x05 COMPACT_I64 = 0x06 COMPACT_DOUBLE = 0x07 COMPACT_BINARY = 0x08 COMPACT_LIST = 0x09 COMPACT_SET = 0x0A COMPACT_MAP = 0x0B COMPACT_STRUCT = 0x0C ) var ( ttypeToCompactType map[TType]tCompactType ) func init() { ttypeToCompactType = map[TType]tCompactType{ STOP: STOP, BOOL: COMPACT_BOOLEAN_TRUE, BYTE: COMPACT_BYTE, I16: COMPACT_I16, I32: COMPACT_I32, I64: COMPACT_I64, DOUBLE: COMPACT_DOUBLE, STRING: COMPACT_BINARY, LIST: COMPACT_LIST, SET: COMPACT_SET, MAP: COMPACT_MAP, STRUCT: COMPACT_STRUCT, } } type TCompactProtocolFactory struct{} func NewTCompactProtocolFactory() *TCompactProtocolFactory { return &TCompactProtocolFactory{} } func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol { return NewTCompactProtocol(trans) } type TCompactProtocol struct { trans TRichTransport origTransport TTransport // Used to keep track of the last field for the current and previous structs, // so we can do the delta stuff. lastField []int lastFieldId int // If we encounter a boolean field begin, save the TField here so it can // have the value incorporated. booleanFieldName string booleanFieldId int16 booleanFieldPending bool // If we read a field header, and it's a boolean field, save the boolean // value here so that readBool can use it. 
boolValue bool boolValueIsNotNull bool buffer [64]byte } // Create a TCompactProtocol given a TTransport func NewTCompactProtocol(trans TTransport) *TCompactProtocol { p := &TCompactProtocol{origTransport: trans, lastField: []int{}} if et, ok := trans.(TRichTransport); ok { p.trans = et } else { p.trans = NewTRichTransport(trans) } return p } // // Public Writing methods. // // Write a message header to the wire. Compact Protocol messages contain the // protocol version so we can migrate forwards in the future if need be. func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { err := p.writeByteDirect(COMPACT_PROTOCOL_ID) if err != nil { return NewTProtocolException(err) } err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK)) if err != nil { return NewTProtocolException(err) } _, err = p.writeVarint32(seqid) if err != nil { return NewTProtocolException(err) } e := p.WriteString(name) return e } func (p *TCompactProtocol) WriteMessageEnd() error { return nil } // Write a struct begin. This doesn't actually put anything on the wire. We // use it as an opportunity to put special placeholder markers on the field // stack so we can get the field id deltas correct. func (p *TCompactProtocol) WriteStructBegin(name string) error { p.lastField = append(p.lastField, p.lastFieldId) p.lastFieldId = 0 return nil } // Write a struct end. This doesn't actually put anything on the wire. We use // this as an opportunity to pop the last field from the current struct off // of the field stack. func (p *TCompactProtocol) WriteStructEnd() error { p.lastFieldId = p.lastField[len(p.lastField)-1] p.lastField = p.lastField[:len(p.lastField)-1] return nil } func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { if typeId == BOOL { // we want to possibly include the value, so we'll wait. 
p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true return nil } _, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF) return NewTProtocolException(err) } // The workhorse of writeFieldBegin. It has the option of doing a // 'type override' of the type header. This is used specifically in the // boolean field case. func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) { // short lastField = lastField_.pop(); // if there's a type override, use that. var typeToWrite byte if typeOverride == 0xFF { typeToWrite = byte(p.getCompactType(typeId)) } else { typeToWrite = typeOverride } // check if we can use delta encoding for the field id fieldId := int(id) written := 0 if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 { // write them together err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite) if err != nil { return 0, err } } else { // write them separate err := p.writeByteDirect(typeToWrite) if err != nil { return 0, err } err = p.WriteI16(id) written = 1 + 2 if err != nil { return 0, err } } p.lastFieldId = fieldId // p.lastField.Push(field.id); return written, nil } func (p *TCompactProtocol) WriteFieldEnd() error { return nil } func (p *TCompactProtocol) WriteFieldStop() error { err := p.writeByteDirect(STOP) return NewTProtocolException(err) } func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { if size == 0 { err := p.writeByteDirect(0) return NewTProtocolException(err) } _, err := p.writeVarint32(int32(size)) if err != nil { return NewTProtocolException(err) } err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType))) return NewTProtocolException(err) } func (p *TCompactProtocol) WriteMapEnd() error { return nil } // Write a list header. 
func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error { _, err := p.writeCollectionBegin(elemType, size) return NewTProtocolException(err) } func (p *TCompactProtocol) WriteListEnd() error { return nil } // Write a set header. func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error { _, err := p.writeCollectionBegin(elemType, size) return NewTProtocolException(err) } func (p *TCompactProtocol) WriteSetEnd() error { return nil } func (p *TCompactProtocol) WriteBool(value bool) error { v := byte(COMPACT_BOOLEAN_FALSE) if value { v = byte(COMPACT_BOOLEAN_TRUE) } if p.booleanFieldPending { // we haven't written the field header yet _, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v) p.booleanFieldPending = false return NewTProtocolException(err) } // we're not part of a field, so just write the value. err := p.writeByteDirect(v) return NewTProtocolException(err) } // Write a byte. Nothing to see here! func (p *TCompactProtocol) WriteByte(value int8) error { err := p.writeByteDirect(byte(value)) return NewTProtocolException(err) } // Write an I16 as a zigzag varint. func (p *TCompactProtocol) WriteI16(value int16) error { _, err := p.writeVarint32(p.int32ToZigzag(int32(value))) return NewTProtocolException(err) } // Write an i32 as a zigzag varint. func (p *TCompactProtocol) WriteI32(value int32) error { _, err := p.writeVarint32(p.int32ToZigzag(value)) return NewTProtocolException(err) } // Write an i64 as a zigzag varint. func (p *TCompactProtocol) WriteI64(value int64) error { _, err := p.writeVarint64(p.int64ToZigzag(value)) return NewTProtocolException(err) } // Write a double to the wire as 8 bytes. func (p *TCompactProtocol) WriteDouble(value float64) error { buf := p.buffer[0:8] binary.LittleEndian.PutUint64(buf, math.Float64bits(value)) _, err := p.trans.Write(buf) return NewTProtocolException(err) } // Write a string to the wire with a varint size preceding. 
func (p *TCompactProtocol) WriteString(value string) error { _, e := p.writeVarint32(int32(len(value))) if e != nil { return NewTProtocolException(e) } if len(value) > 0 { } _, e = p.trans.WriteString(value) return e } // Write a byte array, using a varint for the size. func (p *TCompactProtocol) WriteBinary(bin []byte) error { _, e := p.writeVarint32(int32(len(bin))) if e != nil { return NewTProtocolException(e) } if len(bin) > 0 { _, e = p.trans.Write(bin) return NewTProtocolException(e) } return nil } // // Reading methods. // // Read a message header. func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { protocolId, err := p.readByteDirect() if err != nil { return } if protocolId != COMPACT_PROTOCOL_ID { e := fmt.Errorf("Expected protocol id %02x but got %02x", COMPACT_PROTOCOL_ID, protocolId) return "", typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, e) } versionAndType, err := p.readByteDirect() if err != nil { return } version := versionAndType & COMPACT_VERSION_MASK typeId = TMessageType((versionAndType >> COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_BITS) if version != COMPACT_VERSION { e := fmt.Errorf("Expected version %02x but got %02x", COMPACT_VERSION, version) err = NewTProtocolExceptionWithType(BAD_VERSION, e) return } seqId, e := p.readVarint32() if e != nil { err = NewTProtocolException(e) return } name, err = p.ReadString() return } func (p *TCompactProtocol) ReadMessageEnd() error { return nil } // Read a struct begin. There's nothing on the wire for this, but it is our // opportunity to push a new struct begin marker onto the field stack. func (p *TCompactProtocol) ReadStructBegin() (name string, err error) { p.lastField = append(p.lastField, p.lastFieldId) p.lastFieldId = 0 return } // Doesn't actually consume any wire data, just removes the last field for // this struct from the field stack. 
func (p *TCompactProtocol) ReadStructEnd() error { // consume the last field we read off the wire. p.lastFieldId = p.lastField[len(p.lastField)-1] p.lastField = p.lastField[:len(p.lastField)-1] return nil } // Read a field header off the wire. func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) { t, err := p.readByteDirect() if err != nil { return } // if it's a stop, then we can return immediately, as the struct is over. if (t & 0x0f) == STOP { return "", STOP, 0, nil } // mask off the 4 MSB of the type header. it could contain a field id delta. modifier := int16((t & 0xf0) >> 4) if modifier == 0 { // not a delta. look ahead for the zigzag varint field id. id, err = p.ReadI16() if err != nil { return } } else { // has a delta. add the delta to the last read field id. id = int16(p.lastFieldId) + modifier } typeId, e := p.getTType(tCompactType(t & 0x0f)) if e != nil { err = NewTProtocolException(e) return } // if this happens to be a boolean field, the value is encoded in the type if p.isBoolType(t) { // save the boolean value in a special instance variable. p.boolValue = (byte(t)&0x0f == COMPACT_BOOLEAN_TRUE) p.boolValueIsNotNull = true } // push the new field onto the field stack so we can keep the deltas going. p.lastFieldId = int(id) return } func (p *TCompactProtocol) ReadFieldEnd() error { return nil } // Read a map header off the wire. If the size is zero, skip reading the key // and value type. This means that 0-length maps will yield TMaps without the // "correct" types. 
func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { size32, e := p.readVarint32() if e != nil { err = NewTProtocolException(e) return } if size32 < 0 { err = invalidDataLength return } size = int(size32) keyAndValueType := byte(STOP) if size != 0 { keyAndValueType, err = p.readByteDirect() if err != nil { return } } keyType, _ = p.getTType(tCompactType(keyAndValueType >> 4)) valueType, _ = p.getTType(tCompactType(keyAndValueType & 0xf)) return } func (p *TCompactProtocol) ReadMapEnd() error { return nil } // Read a list header off the wire. If the list size is 0-14, the size will // be packed into the element type header. If it's a longer list, the 4 MSB // of the element type header will be 0xF, and a varint will follow with the // true size. func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) { size_and_type, err := p.readByteDirect() if err != nil { return } size = int((size_and_type >> 4) & 0x0f) if size == 15 { size2, e := p.readVarint32() if e != nil { err = NewTProtocolException(e) return } if size2 < 0 { err = invalidDataLength return } size = int(size2) } elemType, e := p.getTType(tCompactType(size_and_type)) if e != nil { err = NewTProtocolException(e) return } return } func (p *TCompactProtocol) ReadListEnd() error { return nil } // Read a set header off the wire. If the set size is 0-14, the size will // be packed into the element type header. If it's a longer set, the 4 MSB // of the element type header will be 0xF, and a varint will follow with the // true size. func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) { return p.ReadListBegin() } func (p *TCompactProtocol) ReadSetEnd() error { return nil } // Read a boolean off the wire. If this is a boolean field, the value should // already have been read during readFieldBegin, so we'll just consume the // pre-stored value. Otherwise, read a byte. 
func (p *TCompactProtocol) ReadBool() (value bool, err error) { if p.boolValueIsNotNull { p.boolValueIsNotNull = false return p.boolValue, nil } v, err := p.readByteDirect() return v == COMPACT_BOOLEAN_TRUE, err } // Read a single byte off the wire. Nothing interesting here. func (p *TCompactProtocol) ReadByte() (int8, error) { v, err := p.readByteDirect() if err != nil { return 0, NewTProtocolException(err) } return int8(v), err } // Read an i16 from the wire as a zigzag varint. func (p *TCompactProtocol) ReadI16() (value int16, err error) { v, err := p.ReadI32() return int16(v), err } // Read an i32 from the wire as a zigzag varint. func (p *TCompactProtocol) ReadI32() (value int32, err error) { v, e := p.readVarint32() if e != nil { return 0, NewTProtocolException(e) } value = p.zigzagToInt32(v) return value, nil } // Read an i64 from the wire as a zigzag varint. func (p *TCompactProtocol) ReadI64() (value int64, err error) { v, e := p.readVarint64() if e != nil { return 0, NewTProtocolException(e) } value = p.zigzagToInt64(v) return value, nil } // No magic here - just read a double off the wire. func (p *TCompactProtocol) ReadDouble() (value float64, err error) { longBits := p.buffer[0:8] _, e := io.ReadFull(p.trans, longBits) if e != nil { return 0.0, NewTProtocolException(e) } return math.Float64frombits(p.bytesToUint64(longBits)), nil } // Reads a []byte (via readBinary), and then UTF-8 decodes it. func (p *TCompactProtocol) ReadString() (value string, err error) { length, e := p.readVarint32() if e != nil { return "", NewTProtocolException(e) } if length < 0 { return "", invalidDataLength } if uint64(length) > p.trans.RemainingBytes() { return "", invalidDataLength } if length == 0 { return "", nil } var buf []byte if length <= int32(len(p.buffer)) { buf = p.buffer[0:length] } else { buf = make([]byte, length) } _, e = io.ReadFull(p.trans, buf) return string(buf), NewTProtocolException(e) } // Read a []byte from the wire. 
func (p *TCompactProtocol) ReadBinary() (value []byte, err error) { length, e := p.readVarint32() if e != nil { return nil, NewTProtocolException(e) } if length == 0 { return []byte{}, nil } if length < 0 { return nil, invalidDataLength } if uint64(length) > p.trans.RemainingBytes() { return nil, invalidDataLength } buf := make([]byte, length) _, e = io.ReadFull(p.trans, buf) return buf, NewTProtocolException(e) } func (p *TCompactProtocol) Flush() (err error) { return NewTProtocolException(p.trans.Flush()) } func (p *TCompactProtocol) Skip(fieldType TType) (err error) { return SkipDefaultDepth(p, fieldType) } func (p *TCompactProtocol) Transport() TTransport { return p.origTransport } // // Internal writing methods // // Abstract method for writing the start of lists and sets. List and sets on // the wire differ only by the type indicator. func (p *TCompactProtocol) writeCollectionBegin(elemType TType, size int) (int, error) { if size <= 14 { return 1, p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType)))) } err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType))) if err != nil { return 0, err } m, err := p.writeVarint32(int32(size)) return 1 + m, err } // Write an i32 as a varint. Results in 1-5 bytes on the wire. // TODO(pomack): make a permanent buffer like writeVarint64? func (p *TCompactProtocol) writeVarint32(n int32) (int, error) { i32buf := p.buffer[0:5] idx := 0 for { if (n & ^0x7F) == 0 { i32buf[idx] = byte(n) idx++ // p.writeByteDirect(byte(n)); break // return; } else { i32buf[idx] = byte((n & 0x7F) | 0x80) idx++ // p.writeByteDirect(byte(((n & 0x7F) | 0x80))); u := uint32(n) n = int32(u >> 7) } } return p.trans.Write(i32buf[0:idx]) } // Write an i64 as a varint. Results in 1-10 bytes on the wire. 
func (p *TCompactProtocol) writeVarint64(n int64) (int, error) { varint64out := p.buffer[0:10] idx := 0 for { if (n & ^0x7F) == 0 { varint64out[idx] = byte(n) idx++ break } else { varint64out[idx] = byte((n & 0x7F) | 0x80) idx++ u := uint64(n) n = int64(u >> 7) } } return p.trans.Write(varint64out[0:idx]) } // Convert l into a zigzag long. This allows negative numbers to be // represented compactly as a varint. func (p *TCompactProtocol) int64ToZigzag(l int64) int64 { return (l << 1) ^ (l >> 63) } // Convert l into a zigzag long. This allows negative numbers to be // represented compactly as a varint. func (p *TCompactProtocol) int32ToZigzag(n int32) int32 { return (n << 1) ^ (n >> 31) } func (p *TCompactProtocol) fixedUint64ToBytes(n uint64, buf []byte) { binary.LittleEndian.PutUint64(buf, n) } func (p *TCompactProtocol) fixedInt64ToBytes(n int64, buf []byte) { binary.LittleEndian.PutUint64(buf, uint64(n)) } // Writes a byte without any possibility of all that field header nonsense. // Used internally by other writing methods that know they need to write a byte. func (p *TCompactProtocol) writeByteDirect(b byte) error { return p.trans.WriteByte(b) } // Writes a byte without any possibility of all that field header nonsense. func (p *TCompactProtocol) writeIntAsByteDirect(n int) (int, error) { return 1, p.writeByteDirect(byte(n)) } // // Internal reading methods // // Read an i32 from the wire as a varint. The MSB of each byte is set // if there is another byte to follow. This can read up to 5 bytes. func (p *TCompactProtocol) readVarint32() (int32, error) { // if the wire contains the right stuff, this will just truncate the i64 we // read and get us the right sign. v, err := p.readVarint64() return int32(v), err } // Read an i64 from the wire as a proper varint. The MSB of each byte is set // if there is another byte to follow. This can read up to 10 bytes. 
func (p *TCompactProtocol) readVarint64() (int64, error) { shift := uint(0) result := int64(0) for { b, err := p.readByteDirect() if err != nil { return 0, err } result |= int64(b&0x7f) << shift if (b & 0x80) != 0x80 { break } shift += 7 } return result, nil } // Read a byte, unlike ReadByte that reads Thrift-byte that is i8. func (p *TCompactProtocol) readByteDirect() (byte, error) { return p.trans.ReadByte() } // // encoding helpers // // Convert from zigzag int to int. func (p *TCompactProtocol) zigzagToInt32(n int32) int32 { u := uint32(n) return int32(u>>1) ^ -(n & 1) } // Convert from zigzag long to long. func (p *TCompactProtocol) zigzagToInt64(n int64) int64 { u := uint64(n) return int64(u>>1) ^ -(n & 1) } // Note that it's important that the mask bytes are long literals, // otherwise they'll default to ints, and when you shift an int left 56 bits, // you just get a messed up int. func (p *TCompactProtocol) bytesToInt64(b []byte) int64 { return int64(binary.LittleEndian.Uint64(b)) } // Note that it's important that the mask bytes are long literals, // otherwise they'll default to ints, and when you shift an int left 56 bits, // you just get a messed up int. func (p *TCompactProtocol) bytesToUint64(b []byte) uint64 { return binary.LittleEndian.Uint64(b) } // // type testing and converting // func (p *TCompactProtocol) isBoolType(b byte) bool { return (b&0x0f) == COMPACT_BOOLEAN_TRUE || (b&0x0f) == COMPACT_BOOLEAN_FALSE } // Given a tCompactType constant, convert it to its corresponding // TType value. 
func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) { switch byte(t) & 0x0f { case STOP: return STOP, nil case COMPACT_BOOLEAN_FALSE, COMPACT_BOOLEAN_TRUE: return BOOL, nil case COMPACT_BYTE: return BYTE, nil case COMPACT_I16: return I16, nil case COMPACT_I32: return I32, nil case COMPACT_I64: return I64, nil case COMPACT_DOUBLE: return DOUBLE, nil case COMPACT_BINARY: return STRING, nil case COMPACT_LIST: return LIST, nil case COMPACT_SET: return SET, nil case COMPACT_MAP: return MAP, nil case COMPACT_STRUCT: return STRUCT, nil } return STOP, TException(fmt.Errorf("don't know what type: %s", t&0x0f)) } // Given a TType value, find the appropriate TCompactProtocol.Types constant. func (p *TCompactProtocol) getCompactType(t TType) tCompactType { return ttypeToCompactType[t] } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/compact_protocol_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "bytes" "testing" ) func TestReadWriteCompactProtocol(t *testing.T) { ReadWriteProtocolTest(t, NewTCompactProtocolFactory()) transports := []TTransport{ NewTMemoryBuffer(), NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))), NewTFramedTransport(NewTMemoryBuffer()), } for _, trans := range transports { p := NewTCompactProtocol(trans); ReadWriteBool(t, p, trans); p = NewTCompactProtocol(trans); ReadWriteByte(t, p, trans); p = NewTCompactProtocol(trans); ReadWriteI16(t, p, trans); p = NewTCompactProtocol(trans); ReadWriteI32(t, p, trans); p = NewTCompactProtocol(trans); ReadWriteI64(t, p, trans); p = NewTCompactProtocol(trans); ReadWriteDouble(t, p, trans); p = NewTCompactProtocol(trans); ReadWriteString(t, p, trans); p = NewTCompactProtocol(trans); ReadWriteBinary(t, p, trans); trans.Close(); } } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/debug_protocol.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "log" ) type TDebugProtocol struct { Delegate TProtocol LogPrefix string } type TDebugProtocolFactory struct { Underlying TProtocolFactory LogPrefix string } func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory { return &TDebugProtocolFactory{ Underlying: underlying, LogPrefix: logPrefix, } } func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol { return &TDebugProtocol{ Delegate: t.Underlying.GetProtocol(trans), LogPrefix: t.LogPrefix, } } func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid) log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err) return err } func (tdp *TDebugProtocol) WriteMessageEnd() error { err := tdp.Delegate.WriteMessageEnd() log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err) return err } func (tdp *TDebugProtocol) WriteStructBegin(name string) error { err := tdp.Delegate.WriteStructBegin(name) log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err) return err } func (tdp *TDebugProtocol) WriteStructEnd() error { err := tdp.Delegate.WriteStructEnd() log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err) return err } func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { err := tdp.Delegate.WriteFieldBegin(name, typeId, id) log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err) return err } func (tdp *TDebugProtocol) WriteFieldEnd() error { err := tdp.Delegate.WriteFieldEnd() log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err) return err } func (tdp *TDebugProtocol) WriteFieldStop() error { err := tdp.Delegate.WriteFieldStop() log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err) return err } func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType 
TType, size int) error { err := tdp.Delegate.WriteMapBegin(keyType, valueType, size) log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err) return err } func (tdp *TDebugProtocol) WriteMapEnd() error { err := tdp.Delegate.WriteMapEnd() log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err) return err } func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error { err := tdp.Delegate.WriteListBegin(elemType, size) log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) return err } func (tdp *TDebugProtocol) WriteListEnd() error { err := tdp.Delegate.WriteListEnd() log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err) return err } func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error { err := tdp.Delegate.WriteSetBegin(elemType, size) log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) return err } func (tdp *TDebugProtocol) WriteSetEnd() error { err := tdp.Delegate.WriteSetEnd() log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err) return err } func (tdp *TDebugProtocol) WriteBool(value bool) error { err := tdp.Delegate.WriteBool(value) log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } func (tdp *TDebugProtocol) WriteByte(value int8) error { err := tdp.Delegate.WriteByte(value) log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } func (tdp *TDebugProtocol) WriteI16(value int16) error { err := tdp.Delegate.WriteI16(value) log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } func (tdp *TDebugProtocol) WriteI32(value int32) error { err := tdp.Delegate.WriteI32(value) log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } func (tdp *TDebugProtocol) WriteI64(value int64) error { err := tdp.Delegate.WriteI64(value) log.Printf("%sWriteI64(value=%#v) => %#v", 
tdp.LogPrefix, value, err)
	return err
}

// WriteDouble forwards value to the delegate and logs the argument and the
// resulting error, prefixed with LogPrefix.
func (tdp *TDebugProtocol) WriteDouble(value float64) error {
	err := tdp.Delegate.WriteDouble(value)
	log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
	return err
}

// WriteString forwards value to the delegate and logs the argument and the
// resulting error.
func (tdp *TDebugProtocol) WriteString(value string) error {
	err := tdp.Delegate.WriteString(value)
	log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
	return err
}

// WriteBinary forwards value to the delegate and logs the argument and the
// resulting error.
func (tdp *TDebugProtocol) WriteBinary(value []byte) error {
	err := tdp.Delegate.WriteBinary(value)
	log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
	return err
}

// ReadMessageBegin delegates the call and logs every returned value.
func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
	name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin()
	log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
	return
}

// ReadMessageEnd delegates the call and logs the returned error.
func (tdp *TDebugProtocol) ReadMessageEnd() (err error) {
	err = tdp.Delegate.ReadMessageEnd()
	log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
	return
}

// ReadStructBegin delegates the call and logs the returned name and error.
func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) {
	name, err = tdp.Delegate.ReadStructBegin()
	// The format previously read "(name%#v", dropping the '=' that every
	// other method in this type uses between a label and its value.
	log.Printf("%sReadStructBegin() (name=%#v, err=%#v)", tdp.LogPrefix, name, err)
	return
}

// ReadStructEnd delegates the call and logs the returned error.
func (tdp *TDebugProtocol) ReadStructEnd() (err error) {
	err = tdp.Delegate.ReadStructEnd()
	log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
	return
}

// ReadFieldBegin delegates the call and logs every returned value.
func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
	name, typeId, id, err = tdp.Delegate.ReadFieldBegin()
	log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
	return
}

// ReadFieldEnd delegates the call and logs the returned error.
func (tdp *TDebugProtocol) ReadFieldEnd() (err error) {
	err = tdp.Delegate.ReadFieldEnd()
	log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
	return
}

// ReadMapBegin delegates the call and logs every returned value.
func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
	keyType, valueType, size, err = tdp.Delegate.ReadMapBegin()
	log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
	return
}

// ReadMapEnd delegates the call and logs the returned error.
func (tdp *TDebugProtocol) ReadMapEnd() (err error) {
	err = tdp.Delegate.ReadMapEnd()
	log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
	return
}

// ReadListBegin delegates the call and logs every returned value.
func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) {
	elemType, size, err = tdp.Delegate.ReadListBegin()
	log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
	return
}

// ReadListEnd delegates the call and logs the returned error.
func (tdp *TDebugProtocol) ReadListEnd() (err error) {
	err = tdp.Delegate.ReadListEnd()
	log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
	return
}

// ReadSetBegin delegates the call and logs every returned value.
func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) {
	elemType, size, err = tdp.Delegate.ReadSetBegin()
	log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
	return
}

// ReadSetEnd delegates the call and logs the returned error.
func (tdp *TDebugProtocol) ReadSetEnd() (err error) {
	err = tdp.Delegate.ReadSetEnd()
	log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
	return
}

// ReadBool delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadBool() (value bool, err error) {
	value, err = tdp.Delegate.ReadBool()
	log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// ReadByte delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadByte() (value int8, err error) {
	value, err = tdp.Delegate.ReadByte()
	log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// ReadI16 delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadI16() (value int16, err error) {
	value, err = tdp.Delegate.ReadI16()
	log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// ReadI32 delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadI32() (value int32, err error) {
	value, err = tdp.Delegate.ReadI32()
	log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// ReadI64 delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadI64() (value int64, err error) {
	value, err = tdp.Delegate.ReadI64()
	log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// ReadDouble delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) {
	value, err = tdp.Delegate.ReadDouble()
	log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// ReadString delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadString() (value string, err error) {
	value, err = tdp.Delegate.ReadString()
	log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// ReadBinary delegates the call and logs the returned value and error.
func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) {
	value, err = tdp.Delegate.ReadBinary()
	log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
	return
}

// Skip delegates the call and logs the field type and the returned error.
func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
	err = tdp.Delegate.Skip(fieldType)
	log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
	return
}

// Flush delegates the call and logs the returned error.
func (tdp *TDebugProtocol) Flush() (err error) {
	err = tdp.Delegate.Flush()
	log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
	return
}

// Transport returns the delegate's transport; this call is not logged.
func (tdp *TDebugProtocol) Transport() TTransport {
	return tdp.Delegate.Transport()
}

================================================
FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/deserializer.go
================================================
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
*/ package thrift type TDeserializer struct { Transport TTransport Protocol TProtocol } func NewTDeserializer() *TDeserializer { var transport TTransport transport = NewTMemoryBufferLen(1024) protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport) return &TDeserializer{ transport, protocol} } func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) { err = nil if _, err = t.Transport.Write([]byte(s)); err != nil { return } if err = msg.Read(t.Protocol); err != nil { return } return } func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) { err = nil if _, err = t.Transport.Write(b); err != nil { return } if err = msg.Read(t.Protocol); err != nil { return } return } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/exception.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "errors" ) // Generic Thrift exception type TException interface { error } // Prepends additional information to an error without losing the Thrift exception interface func PrependError(prepend string, err error) error { if t, ok := err.(TTransportException); ok { return NewTTransportException(t.TypeId(), prepend+t.Error()) } if t, ok := err.(TProtocolException); ok { return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error())) } if t, ok := err.(TApplicationException); ok { return NewTApplicationException(t.TypeId(), prepend+t.Error()) } return errors.New(prepend + err.Error()) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/exception_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "errors" "testing" ) func TestPrependError(t *testing.T) { err := NewTApplicationException(INTERNAL_ERROR, "original error") err2, ok := PrependError("Prepend: ", err).(TApplicationException) if !ok { t.Fatal("Couldn't cast error TApplicationException") } if err2.Error() != "Prepend: original error" { t.Fatal("Unexpected error string") } if err2.TypeId() != INTERNAL_ERROR { t.Fatal("Unexpected type error") } err3 := NewTProtocolExceptionWithType(INVALID_DATA, errors.New("original error")) err4, ok := PrependError("Prepend: ", err3).(TProtocolException) if !ok { t.Fatal("Couldn't cast error TProtocolException") } if err4.Error() != "Prepend: original error" { t.Fatal("Unexpected error string") } if err4.TypeId() != INVALID_DATA { t.Fatal("Unexpected type error") } err5 := NewTTransportException(TIMED_OUT, "original error") err6, ok := PrependError("Prepend: ", err5).(TTransportException) if !ok { t.Fatal("Couldn't cast error TTransportException") } if err6.Error() != "Prepend: original error" { t.Fatal("Unexpected error string") } if err6.TypeId() != TIMED_OUT { t.Fatal("Unexpected type error") } err7 := errors.New("original error") err8 := PrependError("Prepend: ", err7) if err8.Error() != "Prepend: original error" { t.Fatal("Unexpected error string") } } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/field.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift // Helper class that encapsulates field metadata. type field struct { name string typeId TType id int } func newField(n string, t TType, i int) *field { return &field{name: n, typeId: t, id: i} } func (p *field) Name() string { if p == nil { return "" } return p.name } func (p *field) TypeId() TType { if p == nil { return TType(VOID) } return p.typeId } func (p *field) Id() int { if p == nil { return -1 } return p.id } func (p *field) String() string { if p == nil { return "" } return "" } var ANONYMOUS_FIELD *field type fieldSlice []field func (p fieldSlice) Len() int { return len(p) } func (p fieldSlice) Less(i, j int) bool { return p[i].Id() < p[j].Id() } func (p fieldSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } func init() { ANONYMOUS_FIELD = newField("", STOP, 0) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/framed_transport.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bufio" "bytes" "encoding/binary" "fmt" "io" ) const DEFAULT_MAX_LENGTH = 16384000 type TFramedTransport struct { transport TTransport buf bytes.Buffer reader *bufio.Reader frameSize uint32 //Current remaining size of the frame. if ==0 read next frame header buffer [4]byte maxLength uint32 } type tFramedTransportFactory struct { factory TTransportFactory maxLength uint32 } func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory { return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH} } func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory { return &tFramedTransportFactory{factory: factory, maxLength: maxLength} } func (p *tFramedTransportFactory) GetTransport(base TTransport) TTransport { return NewTFramedTransportMaxLength(p.factory.GetTransport(base), p.maxLength) } func NewTFramedTransport(transport TTransport) *TFramedTransport { return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH} } func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport { return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength} } func (p *TFramedTransport) Open() error { return p.transport.Open() } func (p *TFramedTransport) IsOpen() bool { return p.transport.IsOpen() } func (p *TFramedTransport) Close() error { return p.transport.Close() } func (p *TFramedTransport) Read(buf []byte) (l int, err error) { if p.frameSize == 0 { 
p.frameSize, err = p.readFrameHeader() if err != nil { return } } if p.frameSize < uint32(len(buf)) { frameSize := p.frameSize tmp := make([]byte, p.frameSize) l, err = p.Read(tmp) copy(buf, tmp) if err == nil { err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf))) return } } got, err := p.reader.Read(buf) p.frameSize = p.frameSize - uint32(got) //sanity check if p.frameSize < 0 { return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size") } return got, NewTTransportExceptionFromError(err) } func (p *TFramedTransport) ReadByte() (c byte, err error) { if p.frameSize == 0 { p.frameSize, err = p.readFrameHeader() if err != nil { return } } if p.frameSize < 1 { return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", p.frameSize, 1)) } c, err = p.reader.ReadByte() if err == nil { p.frameSize-- } return } func (p *TFramedTransport) Write(buf []byte) (int, error) { n, err := p.buf.Write(buf) return n, NewTTransportExceptionFromError(err) } func (p *TFramedTransport) WriteByte(c byte) error { return p.buf.WriteByte(c) } func (p *TFramedTransport) WriteString(s string) (n int, err error) { return p.buf.WriteString(s) } func (p *TFramedTransport) Flush() error { size := p.buf.Len() buf := p.buffer[:4] binary.BigEndian.PutUint32(buf, uint32(size)) _, err := p.transport.Write(buf) if err != nil { return NewTTransportExceptionFromError(err) } if size > 0 { if n, err := p.buf.WriteTo(p.transport); err != nil { print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n") return NewTTransportExceptionFromError(err) } } err = p.transport.Flush() return NewTTransportExceptionFromError(err) } func (p *TFramedTransport) readFrameHeader() (uint32, error) { buf := p.buffer[:4] if _, err := io.ReadFull(p.reader, buf); err != nil { return 0, err } size := binary.BigEndian.Uint32(buf) if size < 0 || 
size > p.maxLength { return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size)) } return size, nil } func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) { return uint64(p.frameSize) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/framed_transport_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "testing" ) func TestFramedTransport(t *testing.T) { trans := NewTFramedTransport(NewTMemoryBuffer()) TransportTest(t, trans, trans) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/http_client.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bytes" "io" "io/ioutil" "net/http" "net/url" "strconv" ) // Default to using the shared http client. Library users are // free to change this global client or specify one through // THttpClientOptions. var DefaultHttpClient *http.Client = http.DefaultClient type THttpClient struct { client *http.Client response *http.Response url *url.URL requestBuffer *bytes.Buffer header http.Header nsecConnectTimeout int64 nsecReadTimeout int64 } type THttpClientTransportFactory struct { options THttpClientOptions url string isPost bool } func (p *THttpClientTransportFactory) GetTransport(trans TTransport) TTransport { if trans != nil { t, ok := trans.(*THttpClient) if ok && t.url != nil { if t.requestBuffer != nil { t2, _ := NewTHttpPostClientWithOptions(t.url.String(), p.options) return t2 } t2, _ := NewTHttpClientWithOptions(t.url.String(), p.options) return t2 } } if p.isPost { s, _ := NewTHttpPostClientWithOptions(p.url, p.options) return s } s, _ := NewTHttpClientWithOptions(p.url, p.options) return s } type THttpClientOptions struct { // If nil, DefaultHttpClient is used Client *http.Client } func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory { return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{}) } func NewTHttpClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory { return &THttpClientTransportFactory{url: url, isPost: false, options: options} } func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory { return 
NewTHttpPostClientTransportFactoryWithOptions(url, THttpClientOptions{}) } func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory { return &THttpClientTransportFactory{url: url, isPost: true, options: options} } func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) { parsedURL, err := url.Parse(urlstr) if err != nil { return nil, err } response, err := http.Get(urlstr) if err != nil { return nil, err } client := options.Client if client == nil { client = DefaultHttpClient } httpHeader := map[string][]string{"Content-Type": []string{"application/x-thrift"}} return &THttpClient{client: client, response: response, url: parsedURL, header: httpHeader}, nil } func NewTHttpClient(urlstr string) (TTransport, error) { return NewTHttpClientWithOptions(urlstr, THttpClientOptions{}) } func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) { parsedURL, err := url.Parse(urlstr) if err != nil { return nil, err } buf := make([]byte, 0, 1024) client := options.Client if client == nil { client = DefaultHttpClient } httpHeader := map[string][]string{"Content-Type": []string{"application/x-thrift"}} return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: httpHeader}, nil } func NewTHttpPostClient(urlstr string) (TTransport, error) { return NewTHttpPostClientWithOptions(urlstr, THttpClientOptions{}) } // Set the HTTP Header for this specific Thrift Transport // It is important that you first assert the TTransport as a THttpClient type // like so: // // httpTrans := trans.(THttpClient) // httpTrans.SetHeader("User-Agent","Thrift Client 1.0") func (p *THttpClient) SetHeader(key string, value string) { p.header.Add(key, value) } // Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport // It is important that you first assert the TTransport as a THttpClient type // 
like so: // // httpTrans := trans.(THttpClient) // hdrValue := httpTrans.GetHeader("User-Agent") func (p *THttpClient) GetHeader(key string) string { return p.header.Get(key) } // Deletes the HTTP Header given a Header Key for this specific Thrift Transport // It is important that you first assert the TTransport as a THttpClient type // like so: // // httpTrans := trans.(THttpClient) // httpTrans.DelHeader("User-Agent") func (p *THttpClient) DelHeader(key string) { p.header.Del(key) } func (p *THttpClient) Open() error { // do nothing return nil } func (p *THttpClient) IsOpen() bool { return p.response != nil || p.requestBuffer != nil } func (p *THttpClient) closeResponse() error { var err error if p.response != nil && p.response.Body != nil { // The docs specify that if keepalive is enabled and the response body is not // read to completion the connection will never be returned to the pool and // reused. Errors are being ignored here because if the connection is invalid // and this fails for some reason, the Close() method will do any remaining // cleanup. 
io.Copy(ioutil.Discard, p.response.Body) err = p.response.Body.Close() } p.response = nil return err } func (p *THttpClient) Close() error { if p.requestBuffer != nil { p.requestBuffer.Reset() p.requestBuffer = nil } return p.closeResponse() } func (p *THttpClient) Read(buf []byte) (int, error) { if p.response == nil { return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.") } n, err := p.response.Body.Read(buf) if n > 0 && (err == nil || err == io.EOF) { return n, nil } return n, NewTTransportExceptionFromError(err) } func (p *THttpClient) ReadByte() (c byte, err error) { return readByte(p.response.Body) } func (p *THttpClient) Write(buf []byte) (int, error) { n, err := p.requestBuffer.Write(buf) return n, err } func (p *THttpClient) WriteByte(c byte) error { return p.requestBuffer.WriteByte(c) } func (p *THttpClient) WriteString(s string) (n int, err error) { return p.requestBuffer.WriteString(s) } func (p *THttpClient) Flush() error { // Close any previous response body to avoid leaking connections. p.closeResponse() req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer) if err != nil { return NewTTransportExceptionFromError(err) } req.Header = p.header response, err := p.client.Do(req) if err != nil { return NewTTransportExceptionFromError(err) } if response.StatusCode != http.StatusOK { // Close the response to avoid leaking file descriptors. closeResponse does // more than just call Close(), so temporarily assign it and reuse the logic. 
p.response = response p.closeResponse() // TODO(pomack) log bad response return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "HTTP Response code: "+strconv.Itoa(response.StatusCode)) } p.response = response return nil } func (p *THttpClient) RemainingBytes() (num_bytes uint64) { len := p.response.ContentLength if len >= 0 { return uint64(len) } const maxSize = ^uint64(0) return maxSize // the thruth is, we just don't know unless framed is used } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/http_client_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "net/http" "testing" ) func TestHttpClient(t *testing.T) { l, addr := HttpClientSetupForTest(t) if l != nil { defer l.Close() } trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() t.Fatalf("Unable to connect to %s: %s", addr.String(), err) } TransportTest(t, trans, trans) } func TestHttpClientHeaders(t *testing.T) { l, addr := HttpClientSetupForTest(t) if l != nil { defer l.Close() } trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() t.Fatalf("Unable to connect to %s: %s", addr.String(), err) } TransportHeaderTest(t, trans, trans) } func TestHttpCustomClient(t *testing.T) { l, addr := HttpClientSetupForTest(t) if l != nil { defer l.Close() } httpTransport := &customHttpTransport{} trans, err := NewTHttpPostClientWithOptions("http://"+addr.String(), THttpClientOptions{ Client: &http.Client{ Transport: httpTransport, }, }) if err != nil { l.Close() t.Fatalf("Unable to connect to %s: %s", addr.String(), err) } TransportHeaderTest(t, trans, trans) if !httpTransport.hit { t.Fatalf("Custom client was not used") } } func TestHttpCustomClientPackageScope(t *testing.T) { l, addr := HttpClientSetupForTest(t) if l != nil { defer l.Close() } httpTransport := &customHttpTransport{} DefaultHttpClient = &http.Client{ Transport: httpTransport, } trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() t.Fatalf("Unable to connect to %s: %s", addr.String(), err) } TransportHeaderTest(t, trans, trans) if !httpTransport.hit { t.Fatalf("Custom client was not used") } } type customHttpTransport struct { hit bool } func (c *customHttpTransport) RoundTrip(req *http.Request) (*http.Response, error) { c.hit = true return http.DefaultTransport.RoundTrip(req) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/http_transport.go ================================================ /* * Licensed to the Apache Software 
Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import "net/http" // NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function func NewThriftHandlerFunc(processor TProcessor, inPfactory, outPfactory TProtocolFactory) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/x-thrift") transport := NewStreamTransport(r.Body, w) processor.Process(inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport)) } } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/iostream_transport.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bufio" "io" ) // StreamTransport is a Transport made of an io.Reader and/or an io.Writer type StreamTransport struct { io.Reader io.Writer isReadWriter bool closed bool } type StreamTransportFactory struct { Reader io.Reader Writer io.Writer isReadWriter bool } func (p *StreamTransportFactory) GetTransport(trans TTransport) TTransport { if trans != nil { t, ok := trans.(*StreamTransport) if ok { if t.isReadWriter { return NewStreamTransportRW(t.Reader.(io.ReadWriter)) } if t.Reader != nil && t.Writer != nil { return NewStreamTransport(t.Reader, t.Writer) } if t.Reader != nil && t.Writer == nil { return NewStreamTransportR(t.Reader) } if t.Reader == nil && t.Writer != nil { return NewStreamTransportW(t.Writer) } return &StreamTransport{} } } if p.isReadWriter { return NewStreamTransportRW(p.Reader.(io.ReadWriter)) } if p.Reader != nil && p.Writer != nil { return NewStreamTransport(p.Reader, p.Writer) } if p.Reader != nil && p.Writer == nil { return NewStreamTransportR(p.Reader) } if p.Reader == nil && p.Writer != nil { return NewStreamTransportW(p.Writer) } return &StreamTransport{} } func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory { return &StreamTransportFactory{Reader: reader, Writer: writer, isReadWriter: isReadWriter} } func NewStreamTransport(r io.Reader, w io.Writer) *StreamTransport { return &StreamTransport{Reader: bufio.NewReader(r), Writer: bufio.NewWriter(w)} } func NewStreamTransportR(r io.Reader) *StreamTransport { return &StreamTransport{Reader: 
bufio.NewReader(r)} } func NewStreamTransportW(w io.Writer) *StreamTransport { return &StreamTransport{Writer: bufio.NewWriter(w)} } func NewStreamTransportRW(rw io.ReadWriter) *StreamTransport { bufrw := bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw)) return &StreamTransport{Reader: bufrw, Writer: bufrw, isReadWriter: true} } func (p *StreamTransport) IsOpen() bool { return !p.closed } // implicitly opened on creation, can't be reopened once closed func (p *StreamTransport) Open() error { if !p.closed { return NewTTransportException(ALREADY_OPEN, "StreamTransport already open.") } else { return NewTTransportException(NOT_OPEN, "cannot reopen StreamTransport.") } } // Closes both the input and output streams. func (p *StreamTransport) Close() error { if p.closed { return NewTTransportException(NOT_OPEN, "StreamTransport already closed.") } p.closed = true closedReader := false if p.Reader != nil { c, ok := p.Reader.(io.Closer) if ok { e := c.Close() closedReader = true if e != nil { return e } } p.Reader = nil } if p.Writer != nil && (!closedReader || !p.isReadWriter) { c, ok := p.Writer.(io.Closer) if ok { e := c.Close() if e != nil { return e } } p.Writer = nil } return nil } // Flushes the underlying output stream if not null. 
func (p *StreamTransport) Flush() error { if p.Writer == nil { return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream") } f, ok := p.Writer.(Flusher) if ok { err := f.Flush() if err != nil { return NewTTransportExceptionFromError(err) } } return nil } func (p *StreamTransport) Read(c []byte) (n int, err error) { n, err = p.Reader.Read(c) if err != nil { err = NewTTransportExceptionFromError(err) } return } func (p *StreamTransport) ReadByte() (c byte, err error) { f, ok := p.Reader.(io.ByteReader) if ok { c, err = f.ReadByte() } else { c, err = readByte(p.Reader) } if err != nil { err = NewTTransportExceptionFromError(err) } return } func (p *StreamTransport) Write(c []byte) (n int, err error) { n, err = p.Writer.Write(c) if err != nil { err = NewTTransportExceptionFromError(err) } return } func (p *StreamTransport) WriteByte(c byte) (err error) { f, ok := p.Writer.(io.ByteWriter) if ok { err = f.WriteByte(c) } else { err = writeByte(p.Writer, c) } if err != nil { err = NewTTransportExceptionFromError(err) } return } func (p *StreamTransport) WriteString(s string) (n int, err error) { f, ok := p.Writer.(stringWriter) if ok { n, err = f.WriteString(s) } else { n, err = p.Writer.Write([]byte(s)) } if err != nil { err = NewTTransportExceptionFromError(err) } return } func (p *StreamTransport) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) return maxSize // the thruth is, we just don't know unless framed is used } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/iostream_transport_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bytes" "testing" ) func TestStreamTransport(t *testing.T) { trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024))) TransportTest(t, trans, trans) } func TestStreamTransportOpenClose(t *testing.T) { trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024))) if !trans.IsOpen() { t.Fatal("StreamTransport should be already open") } if trans.Open() == nil { t.Fatal("StreamTransport should return error when open twice") } if trans.Close() != nil { t.Fatal("StreamTransport should not return error when closing open transport") } if trans.IsOpen() { t.Fatal("StreamTransport should not be open after close") } if trans.Close() == nil { t.Fatal("StreamTransport should return error when closing a non open transport") } if trans.Open() == nil { t.Fatal("StreamTransport should not be able to reopen") } } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/json_protocol.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "encoding/base64" "fmt" ) const ( THRIFT_JSON_PROTOCOL_VERSION = 1 ) // for references to _ParseContext see tsimplejson_protocol.go // JSON protocol implementation for thrift. // // This protocol produces/consumes a simple output format // suitable for parsing by scripting languages. It should not be // confused with the full-featured TJSONProtocol. // type TJSONProtocol struct { *TSimpleJSONProtocol } // Constructor func NewTJSONProtocol(t TTransport) *TJSONProtocol { v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)} v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL)) v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL)) return v } // Factory type TJSONProtocolFactory struct{} func (p *TJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { return NewTJSONProtocol(trans) } func NewTJSONProtocolFactory() *TJSONProtocolFactory { return &TJSONProtocolFactory{} } func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { p.resetContextStack() // THRIFT-3735 if e := p.OutputListBegin(); e != nil { return e } if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil { return e } if e := p.WriteString(name); e != nil { return e } if e := p.WriteByte(int8(typeId)); e != nil { return e } if e := p.WriteI32(seqId); e != nil { return e } return nil } func (p *TJSONProtocol) WriteMessageEnd() error { return p.OutputListEnd() } func (p *TJSONProtocol) WriteStructBegin(name string) error { if e := p.OutputObjectBegin(); e != nil { return e } return nil 
} func (p *TJSONProtocol) WriteStructEnd() error { return p.OutputObjectEnd() } func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { if e := p.WriteI16(id); e != nil { return e } if e := p.OutputObjectBegin(); e != nil { return e } s, e1 := p.TypeIdToString(typeId) if e1 != nil { return e1 } if e := p.WriteString(s); e != nil { return e } return nil } func (p *TJSONProtocol) WriteFieldEnd() error { return p.OutputObjectEnd() } func (p *TJSONProtocol) WriteFieldStop() error { return nil } func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } s, e1 := p.TypeIdToString(keyType) if e1 != nil { return e1 } if e := p.WriteString(s); e != nil { return e } s, e1 = p.TypeIdToString(valueType) if e1 != nil { return e1 } if e := p.WriteString(s); e != nil { return e } if e := p.WriteI64(int64(size)); e != nil { return e } return p.OutputObjectBegin() } func (p *TJSONProtocol) WriteMapEnd() error { if e := p.OutputObjectEnd(); e != nil { return e } return p.OutputListEnd() } func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } func (p *TJSONProtocol) WriteListEnd() error { return p.OutputListEnd() } func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } func (p *TJSONProtocol) WriteSetEnd() error { return p.OutputListEnd() } func (p *TJSONProtocol) WriteBool(b bool) error { if b { return p.WriteI32(1) } return p.WriteI32(0) } func (p *TJSONProtocol) WriteByte(b int8) error { return p.WriteI32(int32(b)) } func (p *TJSONProtocol) WriteI16(v int16) error { return p.WriteI32(int32(v)) } func (p *TJSONProtocol) WriteI32(v int32) error { return p.OutputI64(int64(v)) } func (p *TJSONProtocol) WriteI64(v int64) error { return p.OutputI64(int64(v)) } func (p *TJSONProtocol) WriteDouble(v float64) error { return p.OutputF64(v) } func (p 
*TJSONProtocol) WriteString(v string) error { return p.OutputString(v) } func (p *TJSONProtocol) WriteBinary(v []byte) error { // JSON library only takes in a string, // not an arbitrary byte array, to ensure bytes are transmitted // efficiently we must convert this into a valid JSON string // therefore we use base64 encoding to avoid excessive escaping/quoting if e := p.OutputPreValue(); e != nil { return e } if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } writer := base64.NewEncoder(base64.StdEncoding, p.writer) if _, e := writer.Write(v); e != nil { p.writer.Reset(p.trans) // THRIFT-3735 return NewTProtocolException(e) } if e := writer.Close(); e != nil { return NewTProtocolException(e) } if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } return p.OutputPostValue() } // Reading methods. func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { p.resetContextStack() // THRIFT-3735 if isNull, err := p.ParseListBegin(); isNull || err != nil { return name, typeId, seqId, err } version, err := p.ReadI32() if err != nil { return name, typeId, seqId, err } if version != THRIFT_JSON_PROTOCOL_VERSION { e := fmt.Errorf("Unknown Protocol version %d, expected version %d", version, THRIFT_JSON_PROTOCOL_VERSION) return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e) } if name, err = p.ReadString(); err != nil { return name, typeId, seqId, err } bTypeId, err := p.ReadByte() typeId = TMessageType(bTypeId) if err != nil { return name, typeId, seqId, err } if seqId, err = p.ReadI32(); err != nil { return name, typeId, seqId, err } return name, typeId, seqId, nil } func (p *TJSONProtocol) ReadMessageEnd() error { err := p.ParseListEnd() return err } func (p *TJSONProtocol) ReadStructBegin() (name string, err error) { _, err = p.ParseObjectStart() return "", err } func (p *TJSONProtocol) ReadStructEnd() error { return p.ParseObjectEnd() } func (p 
*TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { b, _ := p.reader.Peek(1) if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] { return "", STOP, -1, nil } fieldId, err := p.ReadI16() if err != nil { return "", STOP, fieldId, err } if _, err = p.ParseObjectStart(); err != nil { return "", STOP, fieldId, err } sType, err := p.ReadString() if err != nil { return "", STOP, fieldId, err } fType, err := p.StringToTypeId(sType) return "", fType, fieldId, err } func (p *TJSONProtocol) ReadFieldEnd() error { return p.ParseObjectEnd() } func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, VOID, 0, e } // read keyType sKeyType, e := p.ReadString() if e != nil { return keyType, valueType, size, e } keyType, e = p.StringToTypeId(sKeyType) if e != nil { return keyType, valueType, size, e } // read valueType sValueType, e := p.ReadString() if e != nil { return keyType, valueType, size, e } valueType, e = p.StringToTypeId(sValueType) if e != nil { return keyType, valueType, size, e } // read size iSize, e := p.ReadI64() if e != nil { return keyType, valueType, size, e } size = int(iSize) _, e = p.ParseObjectStart() return keyType, valueType, size, e } func (p *TJSONProtocol) ReadMapEnd() error { e := p.ParseObjectEnd() if e != nil { return e } return p.ParseListEnd() } func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { return p.ParseElemListBegin() } func (p *TJSONProtocol) ReadListEnd() error { return p.ParseListEnd() } func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { return p.ParseElemListBegin() } func (p *TJSONProtocol) ReadSetEnd() error { return p.ParseListEnd() } func (p *TJSONProtocol) ReadBool() (bool, error) { value, err := p.ReadI32() return (value != 0), err } func (p *TJSONProtocol) ReadByte() (int8, error) { v, err := p.ReadI64() return int8(v), err } func (p 
*TJSONProtocol) ReadI16() (int16, error) { v, err := p.ReadI64() return int16(v), err } func (p *TJSONProtocol) ReadI32() (int32, error) { v, err := p.ReadI64() return int32(v), err } func (p *TJSONProtocol) ReadI64() (int64, error) { v, _, err := p.ParseI64() return v, err } func (p *TJSONProtocol) ReadDouble() (float64, error) { v, _, err := p.ParseF64() return v, err } func (p *TJSONProtocol) ReadString() (string, error) { var v string if err := p.ParsePreValue(); err != nil { return v, err } f, _ := p.reader.Peek(1) if len(f) > 0 && f[0] == JSON_QUOTE { p.reader.ReadByte() value, err := p.ParseStringBody() v = value if err != nil { return v, err } } else if len(f) > 0 && f[0] == JSON_NULL[0] { b := make([]byte, len(JSON_NULL)) _, err := p.reader.Read(b) if err != nil { return v, NewTProtocolException(err) } if string(b) != string(JSON_NULL) { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } } else { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } return v, p.ParsePostValue() } func (p *TJSONProtocol) ReadBinary() ([]byte, error) { var v []byte if err := p.ParsePreValue(); err != nil { return nil, err } f, _ := p.reader.Peek(1) if len(f) > 0 && f[0] == JSON_QUOTE { p.reader.ReadByte() value, err := p.ParseBase64EncodedBody() v = value if err != nil { return v, err } } else if len(f) > 0 && f[0] == JSON_NULL[0] { b := make([]byte, len(JSON_NULL)) _, err := p.reader.Read(b) if err != nil { return v, NewTProtocolException(err) } if string(b) != string(JSON_NULL) { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } } else { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } 
return v, p.ParsePostValue() } func (p *TJSONProtocol) Flush() (err error) { err = p.writer.Flush() if err == nil { err = p.trans.Flush() } return NewTProtocolException(err) } func (p *TJSONProtocol) Skip(fieldType TType) (err error) { return SkipDefaultDepth(p, fieldType) } func (p *TJSONProtocol) Transport() TTransport { return p.trans } func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } s, e1 := p.TypeIdToString(elemType) if e1 != nil { return e1 } if e := p.WriteString(s); e != nil { return e } if e := p.WriteI64(int64(size)); e != nil { return e } return nil } func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, 0, e } sElemType, err := p.ReadString() if err != nil { return VOID, size, err } elemType, err = p.StringToTypeId(sElemType) if err != nil { return elemType, size, err } nSize, err2 := p.ReadI64() size = int(nSize) return elemType, size, err2 } func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, 0, e } sElemType, err := p.ReadString() if err != nil { return VOID, size, err } elemType, err = p.StringToTypeId(sElemType) if err != nil { return elemType, size, err } nSize, err2 := p.ReadI64() size = int(nSize) return elemType, size, err2 } func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } s, e1 := p.TypeIdToString(elemType) if e1 != nil { return e1 } if e := p.OutputString(s); e != nil { return e } if e := p.OutputI64(int64(size)); e != nil { return e } return nil } func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) { switch byte(fieldType) { case BOOL: return "tf", nil case BYTE: return "i8", nil case I16: return "i16", nil case I32: return "i32", nil case I64: return "i64", nil case 
DOUBLE: return "dbl", nil case STRING: return "str", nil case STRUCT: return "rec", nil case MAP: return "map", nil case SET: return "set", nil case LIST: return "lst", nil } e := fmt.Errorf("Unknown fieldType: %d", int(fieldType)) return "", NewTProtocolExceptionWithType(INVALID_DATA, e) } func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) { switch fieldType { case "tf": return TType(BOOL), nil case "i8": return TType(BYTE), nil case "i16": return TType(I16), nil case "i32": return TType(I32), nil case "i64": return TType(I64), nil case "dbl": return TType(DOUBLE), nil case "str": return TType(STRING), nil case "rec": return TType(STRUCT), nil case "map": return TType(MAP), nil case "set": return TType(SET), nil case "lst": return TType(LIST), nil } e := fmt.Errorf("Unknown type identifier: %s", fieldType) return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/json_protocol_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "encoding/base64" "encoding/json" "fmt" "math" "strconv" "testing" ) func TestWriteJSONProtocolBool(t *testing.T) { thetype := "boolean" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range BOOL_VALUES { if e := p.WriteBool(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() expected := "" if value { expected = "1" } else { expected = "0" } if s != expected { t.Fatalf("Bad value for %s %v: %s expected", thetype, value, s) } v := -1 if err := json.Unmarshal([]byte(s), &v); err != nil || (v != 0) != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadJSONProtocolBool(t *testing.T) { thetype := "boolean" for _, value := range BOOL_VALUES { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) if value { trans.Write([]byte{'1'}) // not JSON_TRUE } else { trans.Write([]byte{'0'}) // not JSON_FALSE } trans.Flush() s := trans.String() v, e := p.ReadBool() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } vv := -1 if err := json.Unmarshal([]byte(s), &vv); err != nil || (vv != 0) != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, vv) } trans.Reset() trans.Close() } } func TestWriteJSONProtocolByte(t *testing.T) { thetype := "byte" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range BYTE_VALUES { if e := p.WriteByte(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error 
flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int8(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadJSONProtocolByte(t *testing.T) { thetype := "byte" for _, value := range BYTE_VALUES { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) trans.WriteString(strconv.Itoa(int(value))) trans.Flush() s := trans.String() v, e := p.ReadByte() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestWriteJSONProtocolI16(t *testing.T) { thetype := "int16" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range INT16_VALUES { if e := p.WriteI16(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int16(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadJSONProtocolI16(t *testing.T) { thetype := "int16" for _, value := range INT16_VALUES { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) trans.WriteString(strconv.Itoa(int(value))) trans.Flush() s := trans.String() v, e := p.ReadI16() if 
e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestWriteJSONProtocolI32(t *testing.T) { thetype := "int32" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range INT32_VALUES { if e := p.WriteI32(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int32(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadJSONProtocolI32(t *testing.T) { thetype := "int32" for _, value := range INT32_VALUES { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) trans.WriteString(strconv.Itoa(int(value))) trans.Flush() s := trans.String() v, e := p.ReadI32() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestWriteJSONProtocolI64(t *testing.T) { thetype := "int64" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range INT64_VALUES { if e := p.WriteI64(value); e != nil { t.Fatalf("Unable to 
write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int64(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadJSONProtocolI64(t *testing.T) { thetype := "int64" for _, value := range INT64_VALUES { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) trans.WriteString(strconv.FormatInt(value, 10)) trans.Flush() s := trans.String() v, e := p.ReadI64() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestWriteJSONProtocolDouble(t *testing.T) { thetype := "double" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range DOUBLE_VALUES { if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if math.IsInf(value, 1) { if s != jsonQuote(JSON_INFINITY) { t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY)) } } else if math.IsInf(value, -1) { if s != jsonQuote(JSON_NEGATIVE_INFINITY) { t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) } } else if math.IsNaN(value) { if s != 
jsonQuote(JSON_NAN) { t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN)) } } else { if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := float64(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } } trans.Reset() } trans.Close() } func TestReadJSONProtocolDouble(t *testing.T) { thetype := "double" for _, value := range DOUBLE_VALUES { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) n := NewNumericFromDouble(value) trans.WriteString(n.String()) trans.Flush() s := trans.String() v, e := p.ReadDouble() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if math.IsInf(value, 1) { if !math.IsInf(v, 1) { t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) } } else if math.IsInf(value, -1) { if !math.IsInf(v, -1) { t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) } } else if math.IsNaN(value) { if !math.IsNaN(v) { t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) } } else { if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } } trans.Reset() trans.Close() } } func TestWriteJSONProtocolString(t *testing.T) { thetype := "string" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range STRING_VALUES { if e := p.WriteString(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s[0] != '"' || s[len(s)-1] != '"' { 
t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\"")) } v := new(string) if err := json.Unmarshal([]byte(s), v); err != nil || *v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v) } trans.Reset() } trans.Close() } func TestReadJSONProtocolString(t *testing.T) { thetype := "string" for _, value := range STRING_VALUES { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) trans.WriteString(jsonQuote(value)) trans.Flush() s := trans.String() v, e := p.ReadString() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } v1 := new(string) if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) } trans.Reset() trans.Close() } } func TestWriteJSONProtocolBinary(t *testing.T) { thetype := "binary" value := protocol_bdata b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) base64.StdEncoding.Encode(b64value, value) b64String := string(b64value) trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) if e := p.WriteBinary(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() expectedString := fmt.Sprint("\"", b64String, "\"") if s != expectedString { t.Fatalf("Bad value for %s %v\n wrote: \"%v\"\nexpected: \"%v\"", thetype, value, s, expectedString) } v1, err := p.ReadBinary() if err != nil { t.Fatalf("Unable to read binary: %s", err.Error()) } if len(v1) != len(value) { t.Fatalf("Invalid value for binary\nexpected: \"%v\"\n read: \"%v\"", value, v1) } for k, v := range value { if v1[k] 
!= v { t.Fatalf("Invalid value for binary at %v\nexpected: \"%v\"\n read: \"%v\"", k, v, v1[k]) } } trans.Close() } func TestReadJSONProtocolBinary(t *testing.T) { thetype := "binary" value := protocol_bdata b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) base64.StdEncoding.Encode(b64value, value) b64String := string(b64value) trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) trans.WriteString(jsonQuote(b64String)) trans.Flush() s := trans.String() v, e := p.ReadBinary() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if len(v) != len(value) { t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) } for i := 0; i < len(v); i++ { if v[i] != value[i] { t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) } } v1 := new(string) if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) } trans.Reset() trans.Close() } func TestWriteJSONProtocolList(t *testing.T) { thetype := "list" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } p.WriteListEnd() if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } str := trans.String() str1 := new([]interface{}) err := json.Unmarshal([]byte(str), str1) if err != nil { t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) } l := *str1 if len(l) < 2 { t.Fatalf("List must be at least of length two to include metadata") } if l[0] != "dbl" { t.Fatal("Invalid type for list, expected: ", STRING, ", but was: ", l[0]) } if int(l[1].(float64)) != 
len(DOUBLE_VALUES) { t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) } for k, value := range DOUBLE_VALUES { s := l[k+2] if math.IsInf(value, 1) { if s.(string) != JSON_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) } } else if math.IsInf(value, 0) { if s.(string) != JSON_NEGATIVE_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) } } else if math.IsNaN(value) { if s.(string) != JSON_NAN { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) } } else { if s.(float64) != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) } } trans.Reset() } trans.Close() } func TestWriteJSONProtocolSet(t *testing.T) { thetype := "set" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } p.WriteSetEnd() if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } str := trans.String() str1 := new([]interface{}) err := json.Unmarshal([]byte(str), str1) if err != nil { t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) } l := *str1 if len(l) < 2 { t.Fatalf("Set must be at least of length two to include metadata") } if l[0] != "dbl" { t.Fatal("Invalid type for set, expected: ", DOUBLE, ", but was: ", l[0]) } if int(l[1].(float64)) != len(DOUBLE_VALUES) { t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) } for k, value := range DOUBLE_VALUES { s := l[k+2] if math.IsInf(value, 1) { if s.(string) != 
JSON_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) } } else if math.IsInf(value, 0) { if s.(string) != JSON_NEGATIVE_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) } } else if math.IsNaN(value) { if s.(string) != JSON_NAN { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) } } else { if s.(float64) != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) } } trans.Reset() } trans.Close() } func TestWriteJSONProtocolMap(t *testing.T) { thetype := "map" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) for k, value := range DOUBLE_VALUES { if e := p.WriteI32(int32(k)); e != nil { t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) } if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) } } p.WriteMapEnd() if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } str := trans.String() if str[0] != '[' || str[len(str)-1] != ']' { t.Fatalf("Bad value for %s, wrote: %q, in go: %q", thetype, str, DOUBLE_VALUES) } expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin() if err != nil { t.Fatalf("Error while reading map begin: %s", err.Error()) } if expectedKeyType != I32 { t.Fatal("Expected map key type ", I32, ", but was ", expectedKeyType) } if expectedValueType != DOUBLE { t.Fatal("Expected map value type ", DOUBLE, ", but was ", expectedValueType) } if expectedSize != len(DOUBLE_VALUES) { t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", expectedSize) } for 
k, value := range DOUBLE_VALUES { ik, err := p.ReadI32() if err != nil { t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, ik, string(k), err.Error()) } if int(ik) != k { t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v", thetype, k, ik, k) } dv, err := p.ReadDouble() if err != nil { t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, dv, value, err.Error()) } s := strconv.FormatFloat(dv, 'g', 10, 64) if math.IsInf(value, 1) { if !math.IsInf(dv, 1) { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY)) } } else if math.IsInf(value, 0) { if !math.IsInf(dv, 0) { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) } } else if math.IsNaN(value) { if !math.IsNaN(dv) { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN)) } } else { expected := strconv.FormatFloat(value, 'g', 10, 64) if s != expected { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected) } v := float64(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } } } err = p.ReadMapEnd() if err != nil { t.Fatalf("Error while reading map end: %s", err.Error()) } trans.Close() } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/lowlevel_benchmarks_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bytes" "testing" ) var binaryProtoF = NewTBinaryProtocolFactoryDefault() var compactProtoF = NewTCompactProtocolFactory() var buf = bytes.NewBuffer(make([]byte, 0, 1024)) var tfv = []TTransportFactory{ NewTMemoryBufferTransportFactory(1024), NewStreamTransportFactory(buf, buf, true), NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), } func BenchmarkBinaryBool_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) } } func BenchmarkBinaryByte_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) } } func BenchmarkBinaryI16_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) } } func BenchmarkBinaryI32_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkBinaryI64_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkBinaryDouble_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func 
BenchmarkBinaryString_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkBinaryBinary_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) } } func BenchmarkBinaryBool_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) } } func BenchmarkBinaryByte_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) } } func BenchmarkBinaryI16_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) } } func BenchmarkBinaryI32_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkBinaryI64_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkBinaryDouble_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkBinaryString_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkBinaryBinary_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) } } func BenchmarkBinaryBool_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) } } func BenchmarkBinaryByte_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := 
binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) } } func BenchmarkBinaryI16_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) } } func BenchmarkBinaryI32_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkBinaryI64_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkBinaryDouble_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkBinaryString_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkBinaryBinary_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) } } func BenchmarkCompactBool_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) } } func BenchmarkCompactByte_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) } } func BenchmarkCompactI16_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) } } func BenchmarkCompactI32_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkCompactI64_0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func 
BenchmarkCompactDouble0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkCompactString0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkCompactBinary0(b *testing.B) { trans := tfv[0].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) } } func BenchmarkCompactBool_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) } } func BenchmarkCompactByte_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) } } func BenchmarkCompactI16_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) } } func BenchmarkCompactI32_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkCompactI64_1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkCompactDouble1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkCompactString1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkCompactBinary1(b *testing.B) { trans := tfv[1].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) } } func BenchmarkCompactBool_2(b *testing.B) { trans := tfv[2].GetTransport(nil) 
p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) } } func BenchmarkCompactByte_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) } } func BenchmarkCompactI16_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) } } func BenchmarkCompactI32_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkCompactI64_2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkCompactDouble2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkCompactString2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkCompactBinary2(b *testing.B) { trans := tfv[2].GetTransport(nil) p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) } } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/memory_buffer.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bytes" ) // Memory buffer-based implementation of the TTransport interface. type TMemoryBuffer struct { *bytes.Buffer size int } type TMemoryBufferTransportFactory struct { size int } func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) TTransport { if trans != nil { t, ok := trans.(*TMemoryBuffer) if ok && t.size > 0 { return NewTMemoryBufferLen(t.size) } } return NewTMemoryBufferLen(p.size) } func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory { return &TMemoryBufferTransportFactory{size: size} } func NewTMemoryBuffer() *TMemoryBuffer { return &TMemoryBuffer{Buffer: &bytes.Buffer{}, size: 0} } func NewTMemoryBufferLen(size int) *TMemoryBuffer { buf := make([]byte, 0, size) return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size} } func (p *TMemoryBuffer) IsOpen() bool { return true } func (p *TMemoryBuffer) Open() error { return nil } func (p *TMemoryBuffer) Close() error { p.Buffer.Reset() return nil } // Flushing a memory buffer is a no-op func (p *TMemoryBuffer) Flush() error { return nil } func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) { return uint64(p.Buffer.Len()) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/memory_buffer_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "testing" ) func TestMemoryBuffer(t *testing.T) { trans := NewTMemoryBufferLen(1024) TransportTest(t, trans, trans) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/messagetype.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift // Message type constants in the Thrift protocol. 
type TMessageType int32 const ( INVALID_TMESSAGE_TYPE TMessageType = 0 CALL TMessageType = 1 REPLY TMessageType = 2 EXCEPTION TMessageType = 3 ONEWAY TMessageType = 4 ) ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/multiplexed_protocol.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "fmt" "strings" ) /* TMultiplexedProtocol is a protocol-independent concrete decorator that allows a Thrift client to communicate with a multiplexing Thrift server, by prepending the service name to the function name during function calls. NOTE: THIS IS NOT USED BY SERVERS. On the server, use TMultiplexedProcessor to handle request from a multiplexing client. 
This example uses a single socket transport to invoke two services: socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) transport := thrift.NewTFramedTransport(socket) protocol := thrift.NewTBinaryProtocolTransport(transport) mp := thrift.NewTMultiplexedProtocol(protocol, "Calculator") service := Calculator.NewCalculatorClient(mp) mp2 := thrift.NewTMultiplexedProtocol(protocol, "WeatherReport") service2 := WeatherReport.NewWeatherReportClient(mp2) err := transport.Open() if err != nil { t.Fatal("Unable to open client socket", err) } fmt.Println(service.Add(2,2)) fmt.Println(service2.GetTemperature()) */ type TMultiplexedProtocol struct { TProtocol serviceName string } const MULTIPLEXED_SEPARATOR = ":" func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplexedProtocol { return &TMultiplexedProtocol{ TProtocol: protocol, serviceName: serviceName, } } func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { if typeId == CALL || typeId == ONEWAY { return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid) } else { return t.TProtocol.WriteMessageBegin(name, typeId, seqid) } } /* TMultiplexedProcessor is a TProcessor allowing a single TServer to provide multiple services. 
To do so, you instantiate the processor and then register additional processors with it, as shown in the following example: var processor = thrift.NewTMultiplexedProcessor() firstProcessor := processor.RegisterProcessor("FirstService", firstProcessor) processor.registerProcessor( "Calculator", Calculator.NewCalculatorProcessor(&CalculatorHandler{}), ) processor.registerProcessor( "WeatherReport", WeatherReport.NewWeatherReportProcessor(&WeatherReportHandler{}), ) serverTransport, err := thrift.NewTServerSocketTimeout(addr, TIMEOUT) if err != nil { t.Fatal("Unable to create server socket", err) } server := thrift.NewTSimpleServer2(processor, serverTransport) server.Serve(); */ type TMultiplexedProcessor struct { serviceProcessorMap map[string]TProcessor DefaultProcessor TProcessor } func NewTMultiplexedProcessor() *TMultiplexedProcessor { return &TMultiplexedProcessor{ serviceProcessorMap: make(map[string]TProcessor), } } func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) { t.DefaultProcessor = processor } func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProcessor) { if t.serviceProcessorMap == nil { t.serviceProcessorMap = make(map[string]TProcessor) } t.serviceProcessorMap[name] = processor } func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) { name, typeId, seqid, err := in.ReadMessageBegin() if err != nil { return false, err } if typeId != CALL && typeId != ONEWAY { return false, fmt.Errorf("Unexpected message type %v", typeId) } //extract the service name v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2) if len(v) != 2 { if t.DefaultProcessor != nil { smb := NewStoredMessageProtocol(in, name, typeId, seqid) return t.DefaultProcessor.Process(smb, out) } return false, fmt.Errorf("Service name not found in message name: %s. 
Did you forget to use a TMultiplexProtocol in your client?", name) } actualProcessor, ok := t.serviceProcessorMap[v[0]] if !ok { return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0]) } smb := NewStoredMessageProtocol(in, v[1], typeId, seqid) return actualProcessor.Process(smb, out) } //Protocol that use stored message for ReadMessageBegin type storedMessageProtocol struct { TProtocol name string typeId TMessageType seqid int32 } func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageType, seqid int32) *storedMessageProtocol { return &storedMessageProtocol{protocol, name, typeId, seqid} } func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) { return s.name, s.typeId, s.seqid, nil } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/numeric.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "math" "strconv" ) type Numeric interface { Int64() int64 Int32() int32 Int16() int16 Byte() byte Int() int Float64() float64 Float32() float32 String() string isNull() bool } type numeric struct { iValue int64 dValue float64 sValue string isNil bool } var ( INFINITY Numeric NEGATIVE_INFINITY Numeric NAN Numeric ZERO Numeric NUMERIC_NULL Numeric ) func NewNumericFromDouble(dValue float64) Numeric { if math.IsInf(dValue, 1) { return INFINITY } if math.IsInf(dValue, -1) { return NEGATIVE_INFINITY } if math.IsNaN(dValue) { return NAN } iValue := int64(dValue) sValue := strconv.FormatFloat(dValue, 'g', 10, 64) isNil := false return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} } func NewNumericFromI64(iValue int64) Numeric { dValue := float64(iValue) sValue := string(iValue) isNil := false return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} } func NewNumericFromI32(iValue int32) Numeric { dValue := float64(iValue) sValue := string(iValue) isNil := false return &numeric{iValue: int64(iValue), dValue: dValue, sValue: sValue, isNil: isNil} } func NewNumericFromString(sValue string) Numeric { if sValue == INFINITY.String() { return INFINITY } if sValue == NEGATIVE_INFINITY.String() { return NEGATIVE_INFINITY } if sValue == NAN.String() { return NAN } iValue, _ := strconv.ParseInt(sValue, 10, 64) dValue, _ := strconv.ParseFloat(sValue, 64) isNil := len(sValue) == 0 return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil} } func NewNumericFromJSONString(sValue string, isNull bool) Numeric { if isNull { return NewNullNumeric() } if sValue == JSON_INFINITY { return INFINITY } if sValue == JSON_NEGATIVE_INFINITY { return NEGATIVE_INFINITY } if sValue == JSON_NAN { return NAN } iValue, _ := strconv.ParseInt(sValue, 10, 64) dValue, _ := strconv.ParseFloat(sValue, 64) return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNull} } func NewNullNumeric() Numeric { 
return &numeric{iValue: 0, dValue: 0.0, sValue: "", isNil: true} } func (p *numeric) Int64() int64 { return p.iValue } func (p *numeric) Int32() int32 { return int32(p.iValue) } func (p *numeric) Int16() int16 { return int16(p.iValue) } func (p *numeric) Byte() byte { return byte(p.iValue) } func (p *numeric) Int() int { return int(p.iValue) } func (p *numeric) Float64() float64 { return p.dValue } func (p *numeric) Float32() float32 { return float32(p.dValue) } func (p *numeric) String() string { return p.sValue } func (p *numeric) isNull() bool { return p.isNil } func init() { INFINITY = &numeric{iValue: 0, dValue: math.Inf(1), sValue: "Infinity", isNil: false} NEGATIVE_INFINITY = &numeric{iValue: 0, dValue: math.Inf(-1), sValue: "-Infinity", isNil: false} NAN = &numeric{iValue: 0, dValue: math.NaN(), sValue: "NaN", isNil: false} ZERO = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: false} NUMERIC_NULL = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: true} } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/pointerize.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
///////////////////////////////////////////////////////////////////////////////
// Helpers that turn a value of a base type into a pointer to a fresh copy of
// that value. Go cannot take the address of a constant, and a pointer to a
// base type cannot be allocated and initialized in one expression.
//
// E.g., this is not allowed:
//
//	var ip *int = &5
//
// But this *is* allowed:
//
//	func IntPtr(i int) *int { return &i }
//	var ip *int = IntPtr(5)
//
// Pointers to base types are commonplace as [optional] fields in exported
// thrift structs, so the helpers are collected here.
///////////////////////////////////////////////////////////////////////////////

// Float32Ptr returns a pointer to a new float32 holding v.
func Float32Ptr(v float32) *float32 {
	p := new(float32)
	*p = v
	return p
}

// Float64Ptr returns a pointer to a new float64 holding v.
func Float64Ptr(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}

// IntPtr returns a pointer to a new int holding v.
func IntPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}

// Int32Ptr returns a pointer to a new int32 holding v.
func Int32Ptr(v int32) *int32 {
	p := new(int32)
	*p = v
	return p
}

// Int64Ptr returns a pointer to a new int64 holding v.
func Int64Ptr(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}

// StringPtr returns a pointer to a new string holding v.
func StringPtr(v string) *string {
	p := new(string)
	*p = v
	return p
}

// Uint32Ptr returns a pointer to a new uint32 holding v.
func Uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}

// Uint64Ptr returns a pointer to a new uint64 holding v.
func Uint64Ptr(v uint64) *uint64 {
	p := new(uint64)
	*p = v
	return p
}

// BoolPtr returns a pointer to a new bool holding v.
func BoolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}

// ByteSlicePtr returns a pointer to a new slice header equal to v; the
// underlying bytes are shared with v, not copied.
func ByteSlicePtr(v []byte) *[]byte {
	p := new([]byte)
	*p = v
	return p
}
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift // A processor is a generic object which operates upon an input stream and // writes to some output stream. type TProcessor interface { Process(in, out TProtocol) (bool, TException) } type TProcessorFunction interface { Process(seqId int32, in, out TProtocol) (bool, TException) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/processor_factory.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift // The default processor factory just returns a singleton // instance. 
type TProcessorFactory interface {
	// GetProcessor returns the TProcessor to use; it is handed a transport.
	GetProcessor(trans TTransport) TProcessor
}

// tProcessorFactory is the default TProcessorFactory: it holds one processor
// and hands that same instance out on every call.
type tProcessorFactory struct {
	processor TProcessor
}

// NewTProcessorFactory returns a TProcessorFactory whose GetProcessor always
// returns p.
func NewTProcessorFactory(p TProcessor) TProcessorFactory {
	return &tProcessorFactory{processor: p}
}

// GetProcessor returns the stored processor; the trans argument is not used.
func (p *tProcessorFactory) GetProcessor(trans TTransport) TProcessor {
	return p.processor
}

/**
 * TProcessorFunctionFactory is the TProcessorFunction counterpart of
 * TProcessorFactory. The default implementation below just returns a
 * singleton instance.
 */
type TProcessorFunctionFactory interface {
	// GetProcessorFunction returns the TProcessorFunction to use; it is
	// handed a transport.
	GetProcessorFunction(trans TTransport) TProcessorFunction
}

// tProcessorFunctionFactory is the default TProcessorFunctionFactory: it
// holds one processor function and hands that same instance out on every call.
type tProcessorFunctionFactory struct {
	processor TProcessorFunction
}

// NewTProcessorFunctionFactory returns a TProcessorFunctionFactory whose
// GetProcessorFunction always returns p.
func NewTProcessorFunctionFactory(p TProcessorFunction) TProcessorFunctionFactory {
	return &tProcessorFunctionFactory{processor: p}
}

// GetProcessorFunction returns the stored processor function; the trans
// argument is not used.
func (p *tProcessorFunctionFactory) GetProcessorFunction(trans TTransport) TProcessorFunction {
	return p.processor
}

================================================
FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/protocol.go
================================================
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
*/

package thrift

import (
	"errors"
	"fmt"
)

// NOTE(review): these look like the version word of a message header
// (VERSION_MASK selecting the upper 16 bits, VERSION_1 the expected value);
// their use is not visible in this file — confirm against the protocol
// implementations.
const (
	VERSION_MASK = 0xffff0000
	VERSION_1    = 0x80010000
)

// TProtocol is the interface every protocol implements. It pairs a Write*
// method with a Read* method for message, struct, field, map, list and set
// framing and for each scalar type (bool, byte, i16, i32, i64, double,
// string, binary), and additionally exposes Skip, Flush and the underlying
// Transport.
type TProtocol interface {
	WriteMessageBegin(name string, typeId TMessageType, seqid int32) error
	WriteMessageEnd() error
	WriteStructBegin(name string) error
	WriteStructEnd() error
	WriteFieldBegin(name string, typeId TType, id int16) error
	WriteFieldEnd() error
	WriteFieldStop() error
	WriteMapBegin(keyType TType, valueType TType, size int) error
	WriteMapEnd() error
	WriteListBegin(elemType TType, size int) error
	WriteListEnd() error
	WriteSetBegin(elemType TType, size int) error
	WriteSetEnd() error
	WriteBool(value bool) error
	WriteByte(value int8) error
	WriteI16(value int16) error
	WriteI32(value int32) error
	WriteI64(value int64) error
	WriteDouble(value float64) error
	WriteString(value string) error
	WriteBinary(value []byte) error

	ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error)
	ReadMessageEnd() error
	ReadStructBegin() (name string, err error)
	ReadStructEnd() error
	ReadFieldBegin() (name string, typeId TType, id int16, err error)
	ReadFieldEnd() error
	ReadMapBegin() (keyType TType, valueType TType, size int, err error)
	ReadMapEnd() error
	ReadListBegin() (elemType TType, size int, err error)
	ReadListEnd() error
	ReadSetBegin() (elemType TType, size int, err error)
	ReadSetEnd() error
	ReadBool() (value bool, err error)
	ReadByte() (value int8, err error)
	ReadI16() (value int16, err error)
	ReadI32() (value int32, err error)
	ReadI64() (value int64, err error)
	ReadDouble() (value float64, err error)
	ReadString() (value string, err error)
	ReadBinary() (value []byte, err error)

	Skip(fieldType TType) (err error)
	Flush() (err error)

	Transport() TTransport
}

// The maximum recursive depth the skip() function will traverse
const DEFAULT_RECURSION_DEPTH = 64

// Skips over the next data element from the provided input TProtocol object.
func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) { return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH) } // Skips over the next data element from the provided input TProtocol object. func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { if maxDepth <= 0 { return NewTProtocolExceptionWithType( DEPTH_LIMIT, errors.New("Depth limit exceeded")) } switch fieldType { case STOP: return case BOOL: _, err = self.ReadBool() return case BYTE: _, err = self.ReadByte() return case I16: _, err = self.ReadI16() return case I32: _, err = self.ReadI32() return case I64: _, err = self.ReadI64() return case DOUBLE: _, err = self.ReadDouble() return case STRING: _, err = self.ReadString() return case STRUCT: if _, err = self.ReadStructBegin(); err != nil { return err } for { _, typeId, _, _ := self.ReadFieldBegin() if typeId == STOP { break } err := Skip(self, typeId, maxDepth-1) if err != nil { return err } self.ReadFieldEnd() } return self.ReadStructEnd() case MAP: keyType, valueType, size, err := self.ReadMapBegin() if err != nil { return err } for i := 0; i < size; i++ { err := Skip(self, keyType, maxDepth-1) if err != nil { return err } err = Skip(self, valueType, maxDepth-1) if err != nil { return err } } return self.ReadMapEnd() case SET: elemType, size, err := self.ReadSetBegin() if err != nil { return err } for i := 0; i < size; i++ { err := Skip(self, elemType, maxDepth-1) if err != nil { return err } } return self.ReadSetEnd() case LIST: elemType, size, err := self.ReadListBegin() if err != nil { return err } for i := 0; i < size; i++ { err := Skip(self, elemType, maxDepth-1) if err != nil { return err } } return self.ReadListEnd() default: return NewTProtocolExceptionWithType(INVALID_DATA, fmt.Errorf("Unknown data type %d", fieldType)) } return nil } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/protocol_exception.go ================================================ /* * Licensed to 
the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "encoding/base64" ) // Thrift Protocol exception type TProtocolException interface { TException TypeId() int } const ( UNKNOWN_PROTOCOL_EXCEPTION = 0 INVALID_DATA = 1 NEGATIVE_SIZE = 2 SIZE_LIMIT = 3 BAD_VERSION = 4 NOT_IMPLEMENTED = 5 DEPTH_LIMIT = 6 ) type tProtocolException struct { typeId int message string } func (p *tProtocolException) TypeId() int { return p.typeId } func (p *tProtocolException) String() string { return p.message } func (p *tProtocolException) Error() string { return p.message } func NewTProtocolException(err error) TProtocolException { if err == nil { return nil } if e,ok := err.(TProtocolException); ok { return e } if _, ok := err.(base64.CorruptInputError); ok { return &tProtocolException{INVALID_DATA, err.Error()} } return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()} } func NewTProtocolExceptionWithType(errType int, err error) TProtocolException { if err == nil { return nil } return &tProtocolException{errType, err.Error()} } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/protocol_factory.go ================================================ /* * Licensed to the Apache Software Foundation 
(ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package thrift

// TProtocolFactory is the factory interface for constructing protocol
// instances: GetProtocol is handed a TTransport and returns a TProtocol.
type TProtocolFactory interface {
	GetProtocol(trans TTransport) TProtocol
}

================================================
FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/protocol_test.go
================================================
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
*/ package thrift import ( "bytes" "io/ioutil" "math" "net" "net/http" "testing" ) const PROTOCOL_BINARY_DATA_SIZE = 155 var ( data string // test data for writing protocol_bdata []byte // test data for writing; same as data BOOL_VALUES []bool BYTE_VALUES []int8 INT16_VALUES []int16 INT32_VALUES []int32 INT64_VALUES []int64 DOUBLE_VALUES []float64 STRING_VALUES []string ) func init() { protocol_bdata = make([]byte, PROTOCOL_BINARY_DATA_SIZE) for i := 0; i < PROTOCOL_BINARY_DATA_SIZE; i++ { protocol_bdata[i] = byte((i + 'a') % 255) } data = string(protocol_bdata) BOOL_VALUES = []bool{false, true, false, false, true} BYTE_VALUES = []int8{117, 0, 1, 32, 127, -128, -1} INT16_VALUES = []int16{459, 0, 1, -1, -128, 127, 32767, -32768} INT32_VALUES = []int32{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535} INT64_VALUES = []int64{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535, 34359738481, -35184372088719, -9223372036854775808, 9223372036854775807} DOUBLE_VALUES = []float64{459.3, 0.0, -1.0, 1.0, 0.5, 0.3333, 3.14159, 1.537e-38, 1.673e25, 6.02214179e23, -6.02214179e23, INFINITY.Float64(), NEGATIVE_INFINITY.Float64(), NAN.Float64()} STRING_VALUES = []string{"", "a", "st[uf]f", "st,u:ff with spaces", "stuff\twith\nescape\\characters'...\"lots{of}fun"} } type HTTPEchoServer struct{} type HTTPHeaderEchoServer struct{} func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { buf, err := ioutil.ReadAll(req.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write(buf) } else { w.WriteHeader(http.StatusOK) w.Write(buf) } } func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { buf, err := ioutil.ReadAll(req.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write(buf) } else { w.WriteHeader(http.StatusOK) w.Write(buf) } } func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) { addr, err := FindAvailableTCPServerPort(40000) if err != nil { t.Fatalf("Unable to find available tcp 
port addr: %s", err) return nil, addr } l, err := net.Listen(addr.Network(), addr.String()) if err != nil { t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err) return l, addr } go http.Serve(l, &HTTPEchoServer{}) return l, addr } func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) { addr, err := FindAvailableTCPServerPort(40000) if err != nil { t.Fatalf("Unable to find available tcp port addr: %s", err) return nil, addr } l, err := net.Listen(addr.Network(), addr.String()) if err != nil { t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err) return l, addr } go http.Serve(l, &HTTPHeaderEchoServer{}) return l, addr } func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { buf := bytes.NewBuffer(make([]byte, 0, 1024)) l, addr := HttpClientSetupForTest(t) defer l.Close() transports := []TTransportFactory{ NewTMemoryBufferTransportFactory(1024), NewStreamTransportFactory(buf, buf, true), NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), NewTHttpPostClientTransportFactory("http://" + addr.String()), } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteBool(t, p, trans) trans.Close() } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteByte(t, p, trans) trans.Close() } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteI16(t, p, trans) trans.Close() } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteI32(t, p, trans) trans.Close() } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteI64(t, p, trans) trans.Close() } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteDouble(t, p, trans) trans.Close() } for _, tf := range 
transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteString(t, p, trans) trans.Close() } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteBinary(t, p, trans) trans.Close() } for _, tf := range transports { trans := tf.GetTransport(nil) p := protocolFactory.GetProtocol(trans) ReadWriteI64(t, p, trans) ReadWriteDouble(t, p, trans) ReadWriteBinary(t, p, trans) ReadWriteByte(t, p, trans) trans.Close() } } func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(BOOL) thelen := len(BOOL_VALUES) err := p.WriteListBegin(thetype, thelen) if err != nil { t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteBool", p, trans, err, thetype) } for k, v := range BOOL_VALUES { err = p.WriteBool(v) if err != nil { t.Errorf("%s: %T %T %q Error writing bool in list at index %d: %q", "ReadWriteBool", p, trans, err, k, v) } } p.WriteListEnd() if err != nil { t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteBool", p, trans, err, BOOL_VALUES) } p.Flush() thetype2, thelen2, err := p.ReadListBegin() if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteBool", p, trans, err, BOOL_VALUES) } _, ok := p.(*TSimpleJSONProtocol) if !ok { if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", "ReadWriteBool", p, trans, thetype, thetype2) } if thelen != thelen2 { t.Errorf("%s: %T %T len %s != len %s", "ReadWriteBool", p, trans, thelen, thelen2) } } for k, v := range BOOL_VALUES { value, err := p.ReadBool() if err != nil { t.Errorf("%s: %T %T %q Error reading bool at index %d: %q", "ReadWriteBool", p, trans, err, k, v) } if v != value { t.Errorf("%s: index %d %q %q %q != %q", "ReadWriteBool", k, p, trans, v, value) } } err = p.ReadListEnd() if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteBool", p, trans, err) } } func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { thetype := 
TType(BYTE) thelen := len(BYTE_VALUES) err := p.WriteListBegin(thetype, thelen) if err != nil { t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteByte", p, trans, err, thetype) } for k, v := range BYTE_VALUES { err = p.WriteByte(v) if err != nil { t.Errorf("%s: %T %T %q Error writing byte in list at index %d: %q", "ReadWriteByte", p, trans, err, k, v) } } err = p.WriteListEnd() if err != nil { t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) } err = p.Flush() if err != nil { t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) } thetype2, thelen2, err := p.ReadListBegin() if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) } _, ok := p.(*TSimpleJSONProtocol) if !ok { if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", "ReadWriteByte", p, trans, thetype, thetype2) } if thelen != thelen2 { t.Errorf("%s: %T %T len %s != len %s", "ReadWriteByte", p, trans, thelen, thelen2) } } for k, v := range BYTE_VALUES { value, err := p.ReadByte() if err != nil { t.Errorf("%s: %T %T %q Error reading byte at index %d: %q", "ReadWriteByte", p, trans, err, k, v) } if v != value { t.Errorf("%s: %T %T %d != %d", "ReadWriteByte", p, trans, v, value) } } err = p.ReadListEnd() if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteByte", p, trans, err) } } func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I16) thelen := len(INT16_VALUES) p.WriteListBegin(thetype, thelen) for _, v := range INT16_VALUES { p.WriteI16(v) } p.WriteListEnd() p.Flush() thetype2, thelen2, err := p.ReadListBegin() if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES) } _, ok := p.(*TSimpleJSONProtocol) if !ok { if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI16", p, trans, thetype, thetype2) } if 
thelen != thelen2 { t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI16", p, trans, thelen, thelen2) } } for k, v := range INT16_VALUES { value, err := p.ReadI16() if err != nil { t.Errorf("%s: %T %T %q Error reading int16 at index %d: %q", "ReadWriteI16", p, trans, err, k, v) } if v != value { t.Errorf("%s: %T %T %d != %d", "ReadWriteI16", p, trans, v, value) } } err = p.ReadListEnd() if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI16", p, trans, err) } } func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I32) thelen := len(INT32_VALUES) p.WriteListBegin(thetype, thelen) for _, v := range INT32_VALUES { p.WriteI32(v) } p.WriteListEnd() p.Flush() thetype2, thelen2, err := p.ReadListBegin() if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES) } _, ok := p.(*TSimpleJSONProtocol) if !ok { if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI32", p, trans, thetype, thetype2) } if thelen != thelen2 { t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI32", p, trans, thelen, thelen2) } } for k, v := range INT32_VALUES { value, err := p.ReadI32() if err != nil { t.Errorf("%s: %T %T %q Error reading int32 at index %d: %q", "ReadWriteI32", p, trans, err, k, v) } if v != value { t.Errorf("%s: %T %T %d != %d", "ReadWriteI32", p, trans, v, value) } } if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI32", p, trans, err) } } func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I64) thelen := len(INT64_VALUES) p.WriteListBegin(thetype, thelen) for _, v := range INT64_VALUES { p.WriteI64(v) } p.WriteListEnd() p.Flush() thetype2, thelen2, err := p.ReadListBegin() if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES) } _, ok := p.(*TSimpleJSONProtocol) if !ok { if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", 
"ReadWriteI64", p, trans, thetype, thetype2) } if thelen != thelen2 { t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI64", p, trans, thelen, thelen2) } } for k, v := range INT64_VALUES { value, err := p.ReadI64() if err != nil { t.Errorf("%s: %T %T %q Error reading int64 at index %d: %q", "ReadWriteI64", p, trans, err, k, v) } if v != value { t.Errorf("%s: %T %T %q != %q", "ReadWriteI64", p, trans, v, value) } } if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI64", p, trans, err) } } func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(DOUBLE) thelen := len(DOUBLE_VALUES) p.WriteListBegin(thetype, thelen) for _, v := range DOUBLE_VALUES { p.WriteDouble(v) } p.WriteListEnd() p.Flush() thetype2, thelen2, err := p.ReadListBegin() if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES) } if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", "ReadWriteDouble", p, trans, thetype, thetype2) } if thelen != thelen2 { t.Errorf("%s: %T %T len %s != len %s", "ReadWriteDouble", p, trans, thelen, thelen2) } for k, v := range DOUBLE_VALUES { value, err := p.ReadDouble() if err != nil { t.Errorf("%s: %T %T %q Error reading double at index %d: %q", "ReadWriteDouble", p, trans, err, k, v) } if math.IsNaN(v) { if !math.IsNaN(value) { t.Errorf("%s: %T %T math.IsNaN(%q) != math.IsNaN(%q)", "ReadWriteDouble", p, trans, v, value) } } else if v != value { t.Errorf("%s: %T %T %v != %q", "ReadWriteDouble", p, trans, v, value) } } err = p.ReadListEnd() if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteDouble", p, trans, err) } } func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(STRING) thelen := len(STRING_VALUES) p.WriteListBegin(thetype, thelen) for _, v := range STRING_VALUES { p.WriteString(v) } p.WriteListEnd() p.Flush() thetype2, thelen2, err := p.ReadListBegin() if err != nil { t.Errorf("%s: %T 
%T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES) } _, ok := p.(*TSimpleJSONProtocol) if !ok { if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", "ReadWriteString", p, trans, thetype, thetype2) } if thelen != thelen2 { t.Errorf("%s: %T %T len %s != len %s", "ReadWriteString", p, trans, thelen, thelen2) } } for k, v := range STRING_VALUES { value, err := p.ReadString() if err != nil { t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v) } if v != value { t.Errorf("%s: %T %T %d != %d", "ReadWriteString", p, trans, v, value) } } if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteString", p, trans, err) } } func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) { v := protocol_bdata p.WriteBinary(v) p.Flush() value, err := p.ReadBinary() if err != nil { t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error()) } if len(v) != len(value) { t.Errorf("%s: %T %T len(v) != len(value)... %d != %d", "ReadWriteBinary", p, trans, len(v), len(value)) } else { for i := 0; i < len(v); i++ { if v[i] != value[i] { t.Errorf("%s: %T %T %s != %s", "ReadWriteBinary", p, trans, v, value) } } } } func UnmatchedBeginEndProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { // NOTE: not all protocol implementations do strict state check to // return an error on unmatched Begin/End calls. // This test is only meant to make sure that those unmatched Begin/End // calls won't cause panic. There's no real "test" here. 
trans := NewTMemoryBuffer() t.Run("Read", func(t *testing.T) { t.Run("Message", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.ReadMessageEnd() p.ReadMessageEnd() }) t.Run("Struct", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.ReadStructEnd() p.ReadStructEnd() }) t.Run("Field", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.ReadFieldEnd() p.ReadFieldEnd() }) t.Run("Map", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.ReadMapEnd() p.ReadMapEnd() }) t.Run("List", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.ReadListEnd() p.ReadListEnd() }) t.Run("Set", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.ReadSetEnd() p.ReadSetEnd() }) }) t.Run("Write", func(t *testing.T) { t.Run("Message", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.WriteMessageEnd() p.WriteMessageEnd() }) t.Run("Struct", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.WriteStructEnd() p.WriteStructEnd() }) t.Run("Field", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.WriteFieldEnd() p.WriteFieldEnd() }) t.Run("Map", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.WriteMapEnd() p.WriteMapEnd() }) t.Run("List", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.WriteListEnd() p.WriteListEnd() }) t.Run("Set", func(t *testing.T) { trans.Reset() p := protocolFactory.GetProtocol(trans) p.WriteSetEnd() p.WriteSetEnd() }) }) trans.Close() } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/rich_transport.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import "io" type RichTransport struct { TTransport } // Wraps Transport to provide TRichTransport interface func NewTRichTransport(trans TTransport) *RichTransport { return &RichTransport{trans} } func (r *RichTransport) ReadByte() (c byte, err error) { return readByte(r.TTransport) } func (r *RichTransport) WriteByte(c byte) error { return writeByte(r.TTransport, c) } func (r *RichTransport) WriteString(s string) (n int, err error) { return r.Write([]byte(s)) } func (r *RichTransport) RemainingBytes() (num_bytes uint64) { return r.TTransport.RemainingBytes() } func readByte(r io.Reader) (c byte, err error) { v := [1]byte{0} n, err := r.Read(v[0:1]) if n > 0 && (err == nil || err == io.EOF) { return v[0], nil } if n > 0 && err != nil { return v[0], err } if err != nil { return 0, err } return v[0], nil } func writeByte(w io.Writer, c byte) error { v := [1]byte{c} _, err := w.Write(v[0:1]) return err } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/rich_transport_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "bytes" "errors" "io" "reflect" "testing" ) func TestEnsureTransportsAreRich(t *testing.T) { buf := bytes.NewBuffer(make([]byte, 0, 1024)) transports := []TTransportFactory{ NewTMemoryBufferTransportFactory(1024), NewStreamTransportFactory(buf, buf, true), NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), NewTHttpPostClientTransportFactory("http://127.0.0.1"), } for _, tf := range transports { trans := tf.GetTransport(nil) _, ok := trans.(TRichTransport) if !ok { t.Errorf("Transport %s does not implement TRichTransport interface", reflect.ValueOf(trans)) } } } // TestReadByte tests whether readByte handles error cases correctly. func TestReadByte(t *testing.T) { for i, test := range readByteTests { v, err := readByte(test.r) if v != test.v { t.Fatalf("TestReadByte %d: value differs. Expected %d, got %d", i, test.v, test.r.v) } if err != test.err { t.Fatalf("TestReadByte %d: error differs. 
Expected %s, got %s", i, test.err, test.r.err) } } } var someError = errors.New("Some error") var readByteTests = []struct { r *mockReader v byte err error }{ {&mockReader{0, 55, io.EOF}, 0, io.EOF}, // reader sends EOF w/o data {&mockReader{0, 55, someError}, 0, someError}, // reader sends some other error {&mockReader{1, 55, nil}, 55, nil}, // reader sends data w/o error {&mockReader{1, 55, io.EOF}, 55, nil}, // reader sends data with EOF {&mockReader{1, 55, someError}, 55, someError}, // reader sends data withsome error } type mockReader struct { n int v byte err error } func (r *mockReader) Read(p []byte) (n int, err error) { if r.n > 0 { p[0] = r.v } return r.n, r.err } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/serializer.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift type TSerializer struct { Transport *TMemoryBuffer Protocol TProtocol } type TStruct interface { Write(p TProtocol) error Read(p TProtocol) error } func NewTSerializer() *TSerializer { transport := NewTMemoryBufferLen(1024) protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport) return &TSerializer{ transport, protocol} } func (t *TSerializer) WriteString(msg TStruct) (s string, err error) { t.Transport.Reset() if err = msg.Write(t.Protocol); err != nil { return } if err = t.Protocol.Flush(); err != nil { return } if err = t.Transport.Flush(); err != nil { return } return t.Transport.String(), nil } func (t *TSerializer) Write(msg TStruct) (b []byte, err error) { t.Transport.Reset() if err = msg.Write(t.Protocol); err != nil { return } if err = t.Protocol.Flush(); err != nil { return } if err = t.Transport.Flush(); err != nil { return } b = append(b, t.Transport.Bytes()...) return } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/serializer_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "errors" "fmt" "testing" ) type ProtocolFactory interface { GetProtocol(t TTransport) TProtocol } func compareStructs(m, m1 MyTestStruct) (bool, error) { switch { case m.On != m1.On: return false, errors.New("Boolean not equal") case m.B != m1.B: return false, errors.New("Byte not equal") case m.Int16 != m1.Int16: return false, errors.New("Int16 not equal") case m.Int32 != m1.Int32: return false, errors.New("Int32 not equal") case m.Int64 != m1.Int64: return false, errors.New("Int64 not equal") case m.D != m1.D: return false, errors.New("Double not equal") case m.St != m1.St: return false, errors.New("String not equal") case len(m.Bin) != len(m1.Bin): return false, errors.New("Binary size not equal") case len(m.Bin) == len(m1.Bin): for i := range m.Bin { if m.Bin[i] != m1.Bin[i] { return false, errors.New("Binary not equal") } } case len(m.StringMap) != len(m1.StringMap): return false, errors.New("StringMap size not equal") case len(m.StringList) != len(m1.StringList): return false, errors.New("StringList size not equal") case len(m.StringSet) != len(m1.StringSet): return false, errors.New("StringSet size not equal") case m.E != m1.E: return false, errors.New("MyTestEnum not equal") default: return true, nil } return true, nil } func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) { t := NewTSerializer() t.Protocol = pf.GetProtocol(t.Transport) var m = MyTestStruct{} m.On = true m.B = int8(0) m.Int16 = 1 m.Int32 = 2 m.Int64 = 3 m.D = 4.1 m.St = "Test" m.Bin = make([]byte, 10) m.StringMap = make(map[string]string, 5) m.StringList = make([]string, 5) m.StringSet = make(map[string]struct{}, 5) m.E = 2 s, err := t.WriteString(&m) if err != nil { return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err)) } t1 := NewTDeserializer() t1.Protocol = pf.GetProtocol(t1.Transport) var m1 = MyTestStruct{} if err = t1.ReadString(&m1, s); err != nil { return false, errors.New(fmt.Sprintf("Unable to Deserialize 
struct\n\t %s", err)) } return compareStructs(m, m1) } func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) { t := NewTSerializer() t.Protocol = pf.GetProtocol(t.Transport) var m = MyTestStruct{} m.On = false m.B = int8(0) m.Int16 = 1 m.Int32 = 2 m.Int64 = 3 m.D = 4.1 m.St = "Test" m.Bin = make([]byte, 10) m.StringMap = make(map[string]string, 5) m.StringList = make([]string, 5) m.StringSet = make(map[string]struct{}, 5) m.E = 2 s, err := t.WriteString(&m) if err != nil { return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err)) } t1 := NewTDeserializer() t1.Protocol = pf.GetProtocol(t1.Transport) var m1 = MyTestStruct{} if err = t1.ReadString(&m1, s); err != nil { return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err)) } return compareStructs(m, m1) } func TestSerializer(t *testing.T) { var protocol_factories map[string]ProtocolFactory protocol_factories = make(map[string]ProtocolFactory) protocol_factories["Binary"] = NewTBinaryProtocolFactoryDefault() protocol_factories["Compact"] = NewTCompactProtocolFactory() //protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - write only, can't be read back by design protocol_factories["JSON"] = NewTJSONProtocolFactory() var tests map[string]func(*testing.T, ProtocolFactory) (bool, error) tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error)) tests["Test 1"] = ProtocolTest1 tests["Test 2"] = ProtocolTest2 //tests["Test 3"] = ProtocolTest3 // Example of how to add additional tests for name, pf := range protocol_factories { for test, f := range tests { if s, err := f(t, pf); !s || err != nil { t.Errorf("%s Failed for %s protocol\n\t %s", test, name, err) } } } } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/serializer_types_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor 
license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING /* THE FOLLOWING THRIFT FILE WAS USED TO CREATE THIS enum MyTestEnum { FIRST = 1, SECOND = 2, THIRD = 3, FOURTH = 4, } struct MyTestStruct { 1: bool on, 2: byte b, 3: i16 int16, 4: i32 int32, 5: i64 int64, 6: double d, 7: string st, 8: binary bin, 9: map<string, string> stringMap, 10: list<string> stringList, 11: set<string> stringSet, 12: MyTestEnum e, } */ import ( "fmt" ) // (needed to ensure safety because of naive import list construction.)
// MyTestEnum mirrors the MyTestEnum enum of the Thrift definition used by the
// serializer tests.
type MyTestEnum int64

const (
	MyTestEnum_FIRST  MyTestEnum = 1
	MyTestEnum_SECOND MyTestEnum = 2
	MyTestEnum_THIRD  MyTestEnum = 3
	MyTestEnum_FOURTH MyTestEnum = 4
)

// String returns the symbolic name of p, or the empty string when p is not
// one of the declared constants.
func (p MyTestEnum) String() string {
	names := [...]string{"FIRST", "SECOND", "THIRD", "FOURTH"}
	if p < MyTestEnum_FIRST || p > MyTestEnum_FOURTH {
		return ""
	}
	return names[p-MyTestEnum_FIRST]
}

// MyTestEnumFromString maps a symbolic name back to its constant. Names that
// match no constant yield MyTestEnum(0) and an error.
func MyTestEnumFromString(s string) (MyTestEnum, error) {
	for v := MyTestEnum_FIRST; v <= MyTestEnum_FOURTH; v++ {
		if v.String() == s {
			return v, nil
		}
	}
	return MyTestEnum(0), fmt.Errorf("not a valid MyTestEnum string")
}

// MyTestEnumPtr returns a pointer to a copy of v.
func MyTestEnumPtr(v MyTestEnum) *MyTestEnum {
	out := v
	return &out
}
GetStringList() []string { return p.StringList } func (p *MyTestStruct) GetStringSet() map[string]struct{} { return p.StringSet } func (p *MyTestStruct) GetE() MyTestEnum { return p.E } func (p *MyTestStruct) Read(iprot TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == STOP { break } switch fieldId { case 1: if err := p.readField1(iprot); err != nil { return err } case 2: if err := p.readField2(iprot); err != nil { return err } case 3: if err := p.readField3(iprot); err != nil { return err } case 4: if err := p.readField4(iprot); err != nil { return err } case 5: if err := p.readField5(iprot); err != nil { return err } case 6: if err := p.readField6(iprot); err != nil { return err } case 7: if err := p.readField7(iprot); err != nil { return err } case 8: if err := p.readField8(iprot); err != nil { return err } case 9: if err := p.readField9(iprot); err != nil { return err } case 10: if err := p.readField10(iprot); err != nil { return err } case 11: if err := p.readField11(iprot); err != nil { return err } case 12: if err := p.readField12(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *MyTestStruct) readField1(iprot TProtocol) error { if v, err := iprot.ReadBool(); err != nil { return PrependError("error reading field 1: ", err) } else { p.On = v } return nil } func (p *MyTestStruct) readField2(iprot TProtocol) error { if v, err := iprot.ReadByte(); err != nil { return PrependError("error reading field 2: ", err) } else { temp := int8(v) p.B = temp } return 
nil } func (p *MyTestStruct) readField3(iprot TProtocol) error { if v, err := iprot.ReadI16(); err != nil { return PrependError("error reading field 3: ", err) } else { p.Int16 = v } return nil } func (p *MyTestStruct) readField4(iprot TProtocol) error { if v, err := iprot.ReadI32(); err != nil { return PrependError("error reading field 4: ", err) } else { p.Int32 = v } return nil } func (p *MyTestStruct) readField5(iprot TProtocol) error { if v, err := iprot.ReadI64(); err != nil { return PrependError("error reading field 5: ", err) } else { p.Int64 = v } return nil } func (p *MyTestStruct) readField6(iprot TProtocol) error { if v, err := iprot.ReadDouble(); err != nil { return PrependError("error reading field 6: ", err) } else { p.D = v } return nil } func (p *MyTestStruct) readField7(iprot TProtocol) error { if v, err := iprot.ReadString(); err != nil { return PrependError("error reading field 7: ", err) } else { p.St = v } return nil } func (p *MyTestStruct) readField8(iprot TProtocol) error { if v, err := iprot.ReadBinary(); err != nil { return PrependError("error reading field 8: ", err) } else { p.Bin = v } return nil } func (p *MyTestStruct) readField9(iprot TProtocol) error { _, _, size, err := iprot.ReadMapBegin() if err != nil { return PrependError("error reading map begin: ", err) } tMap := make(map[string]string, size) p.StringMap = tMap for i := 0; i < size; i++ { var _key0 string if v, err := iprot.ReadString(); err != nil { return PrependError("error reading field 0: ", err) } else { _key0 = v } var _val1 string if v, err := iprot.ReadString(); err != nil { return PrependError("error reading field 0: ", err) } else { _val1 = v } p.StringMap[_key0] = _val1 } if err := iprot.ReadMapEnd(); err != nil { return PrependError("error reading map end: ", err) } return nil } func (p *MyTestStruct) readField10(iprot TProtocol) error { _, size, err := iprot.ReadListBegin() if err != nil { return PrependError("error reading list begin: ", err) } tSlice := 
make([]string, 0, size) p.StringList = tSlice for i := 0; i < size; i++ { var _elem2 string if v, err := iprot.ReadString(); err != nil { return PrependError("error reading field 0: ", err) } else { _elem2 = v } p.StringList = append(p.StringList, _elem2) } if err := iprot.ReadListEnd(); err != nil { return PrependError("error reading list end: ", err) } return nil } func (p *MyTestStruct) readField11(iprot TProtocol) error { _, size, err := iprot.ReadSetBegin() if err != nil { return PrependError("error reading set begin: ", err) } tSet := make(map[string]struct{}, size) p.StringSet = tSet for i := 0; i < size; i++ { var _elem3 string if v, err := iprot.ReadString(); err != nil { return PrependError("error reading field 0: ", err) } else { _elem3 = v } p.StringSet[_elem3] = struct{}{} } if err := iprot.ReadSetEnd(); err != nil { return PrependError("error reading set end: ", err) } return nil } func (p *MyTestStruct) readField12(iprot TProtocol) error { if v, err := iprot.ReadI32(); err != nil { return PrependError("error reading field 12: ", err) } else { temp := MyTestEnum(v) p.E = temp } return nil } func (p *MyTestStruct) Write(oprot TProtocol) error { if err := oprot.WriteStructBegin("MyTestStruct"); err != nil { return PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := p.writeField3(oprot); err != nil { return err } if err := p.writeField4(oprot); err != nil { return err } if err := p.writeField5(oprot); err != nil { return err } if err := p.writeField6(oprot); err != nil { return err } if err := p.writeField7(oprot); err != nil { return err } if err := p.writeField8(oprot); err != nil { return err } if err := p.writeField9(oprot); err != nil { return err } if err := p.writeField10(oprot); err != nil { return err } if err := p.writeField11(oprot); err != nil { return err } if err := p.writeField12(oprot); err != nil 
{ return err } if err := oprot.WriteFieldStop(); err != nil { return PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return PrependError("write struct stop error: ", err) } return nil } func (p *MyTestStruct) writeField1(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("on", BOOL, 1); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 1:on: ", p), err) } if err := oprot.WriteBool(bool(p.On)); err != nil { return PrependError(fmt.Sprintf("%T.on (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 1:on: ", p), err) } return err } func (p *MyTestStruct) writeField2(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("b", BYTE, 2); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 2:b: ", p), err) } if err := oprot.WriteByte(int8(p.B)); err != nil { return PrependError(fmt.Sprintf("%T.b (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 2:b: ", p), err) } return err } func (p *MyTestStruct) writeField3(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("int16", I16, 3); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 3:int16: ", p), err) } if err := oprot.WriteI16(int16(p.Int16)); err != nil { return PrependError(fmt.Sprintf("%T.int16 (3) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 3:int16: ", p), err) } return err } func (p *MyTestStruct) writeField4(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("int32", I32, 4); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 4:int32: ", p), err) } if err := oprot.WriteI32(int32(p.Int32)); err != nil { return PrependError(fmt.Sprintf("%T.int32 (4) field write error: ", p), 
err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 4:int32: ", p), err) } return err } func (p *MyTestStruct) writeField5(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("int64", I64, 5); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 5:int64: ", p), err) } if err := oprot.WriteI64(int64(p.Int64)); err != nil { return PrependError(fmt.Sprintf("%T.int64 (5) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 5:int64: ", p), err) } return err } func (p *MyTestStruct) writeField6(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("d", DOUBLE, 6); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 6:d: ", p), err) } if err := oprot.WriteDouble(float64(p.D)); err != nil { return PrependError(fmt.Sprintf("%T.d (6) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 6:d: ", p), err) } return err } func (p *MyTestStruct) writeField7(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("st", STRING, 7); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 7:st: ", p), err) } if err := oprot.WriteString(string(p.St)); err != nil { return PrependError(fmt.Sprintf("%T.st (7) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 7:st: ", p), err) } return err } func (p *MyTestStruct) writeField8(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("bin", STRING, 8); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 8:bin: ", p), err) } if err := oprot.WriteBinary(p.Bin); err != nil { return PrependError(fmt.Sprintf("%T.bin (8) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return 
PrependError(fmt.Sprintf("%T write field end error 8:bin: ", p), err) } return err } func (p *MyTestStruct) writeField9(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("stringMap", MAP, 9); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 9:stringMap: ", p), err) } if err := oprot.WriteMapBegin(STRING, STRING, len(p.StringMap)); err != nil { return PrependError("error writing map begin: ", err) } for k, v := range p.StringMap { if err := oprot.WriteString(string(k)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } if err := oprot.WriteString(string(v)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } } if err := oprot.WriteMapEnd(); err != nil { return PrependError("error writing map end: ", err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 9:stringMap: ", p), err) } return err } func (p *MyTestStruct) writeField10(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("stringList", LIST, 10); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 10:stringList: ", p), err) } if err := oprot.WriteListBegin(STRING, len(p.StringList)); err != nil { return PrependError("error writing list begin: ", err) } for _, v := range p.StringList { if err := oprot.WriteString(string(v)); err != nil { return PrependError(fmt.Sprintf("%T. 
(0) field write error: ", p), err) } } if err := oprot.WriteListEnd(); err != nil { return PrependError("error writing list end: ", err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 10:stringList: ", p), err) } return err } func (p *MyTestStruct) writeField11(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("stringSet", SET, 11); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 11:stringSet: ", p), err) } if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil { return PrependError("error writing set begin: ", err) } for v, _ := range p.StringSet { if err := oprot.WriteString(string(v)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } } if err := oprot.WriteSetEnd(); err != nil { return PrependError("error writing set end: ", err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 11:stringSet: ", p), err) } return err } func (p *MyTestStruct) writeField12(oprot TProtocol) (err error) { if err := oprot.WriteFieldBegin("e", I32, 12); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 12:e: ", p), err) } if err := oprot.WriteI32(int32(p.E)); err != nil { return PrependError(fmt.Sprintf("%T.e (12) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return PrependError(fmt.Sprintf("%T write field end error 12:e: ", p), err) } return err } func (p *MyTestStruct) String() string { if p == nil { return "" } return fmt.Sprintf("MyTestStruct(%+v)", *p) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/server.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
// TServer is the interface for a Thrift server. Besides the Serve/Stop
// lifecycle methods it exposes accessors for the processor factory, the
// server transport, and the input/output transport and protocol factories.
type TServer interface {
	ProcessorFactory() TProcessorFactory
	ServerTransport() TServerTransport
	InputTransportFactory() TTransportFactory
	OutputTransportFactory() TTransportFactory
	InputProtocolFactory() TProtocolFactory
	OutputProtocolFactory() TProtocolFactory

	// Starts the server
	Serve() error
	// Stops the server. This is optional on a per-implementation basis. Not
	// all servers are required to be cleanly stoppable.
	Stop() error
}
See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "net" "sync" "time" ) type TServerSocket struct { listener net.Listener addr net.Addr clientTimeout time.Duration // Protects the interrupted value to make it thread safe. mu sync.RWMutex interrupted bool } func NewTServerSocket(listenAddr string) (*TServerSocket, error) { return NewTServerSocketTimeout(listenAddr, 0) } func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*TServerSocket, error) { addr, err := net.ResolveTCPAddr("tcp", listenAddr) if err != nil { return nil, err } return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil } func (p *TServerSocket) Listen() error { if p.IsListening() { return nil } l, err := net.Listen(p.addr.Network(), p.addr.String()) if err != nil { return err } p.listener = l return nil } func (p *TServerSocket) Accept() (TTransport, error) { p.mu.RLock() interrupted := p.interrupted p.mu.RUnlock() if interrupted { return nil, errTransportInterrupted } if p.listener == nil { return nil, NewTTransportException(NOT_OPEN, "No underlying server socket") } conn, err := p.listener.Accept() if err != nil { return nil, NewTTransportExceptionFromError(err) } return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil } // Checks whether the socket is listening. func (p *TServerSocket) IsListening() bool { return p.listener != nil } // Connects the socket, creating a new socket object if necessary. 
func (p *TServerSocket) Open() error { if p.IsListening() { return NewTTransportException(ALREADY_OPEN, "Server socket already open") } if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil { return err } else { p.listener = l } return nil } func (p *TServerSocket) Addr() net.Addr { if p.listener != nil { return p.listener.Addr() } return p.addr } func (p *TServerSocket) Close() error { defer func() { p.listener = nil }() if p.IsListening() { return p.listener.Close() } return nil } func (p *TServerSocket) Interrupt() error { p.mu.Lock() p.interrupted = true p.Close() p.mu.Unlock() return nil } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/server_socket_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "fmt" "testing" ) func TestSocketIsntListeningAfterInterrupt(t *testing.T) { host := "127.0.0.1" port := 9090 addr := fmt.Sprintf("%s:%d", host, port) socket := CreateServerSocket(t, addr) socket.Listen() socket.Interrupt() newSocket := CreateServerSocket(t, addr) err := newSocket.Listen() defer newSocket.Interrupt() if err != nil { t.Fatalf("Failed to rebinds: %s", err) } } func CreateServerSocket(t *testing.T, addr string) *TServerSocket { socket, err := NewTServerSocket(addr) if err != nil { t.Fatalf("Failed to create server socket: %s", err) } return socket } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/server_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "testing" ) func TestNothing(t *testing.T) { } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/server_transport.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
// Server transport. Object which provides client transports.
//
// Accept is the method that hands out those client transports (it returns a
// TTransport); Listen and Close carry no further documentation here, so see
// the concrete implementations for their exact behaviour.
type TServerTransport interface {
	Listen() error
	Accept() (TTransport, error)
	Close() error

	// Optional method implementation. This signals to the server transport
	// that it should break out of any accept() or listen() that it is currently
	// blocked on. This method, if implemented, MUST be thread safe, as it may
	// be called from a different thread context than the other TServerTransport
	// methods.
	Interrupt() error
}
// _ParseContext identifies a position in the JSON document being processed:
// top level, inside a list, or inside an object.
type _ParseContext int

const (
	_CONTEXT_IN_TOPLEVEL          _ParseContext = 1
	_CONTEXT_IN_LIST_FIRST        _ParseContext = 2
	_CONTEXT_IN_LIST              _ParseContext = 3
	_CONTEXT_IN_OBJECT_FIRST      _ParseContext = 4
	_CONTEXT_IN_OBJECT_NEXT_KEY   _ParseContext = 5
	_CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6
)

// String returns the label for p, or "UNKNOWN-PARSE-CONTEXT" when p is not
// one of the declared contexts.
func (p _ParseContext) String() string {
	// Index 0 is unused because the declared contexts start at 1.
	labels := [...]string{
		_CONTEXT_IN_TOPLEVEL:          "TOPLEVEL",
		_CONTEXT_IN_LIST_FIRST:        "LIST-FIRST",
		_CONTEXT_IN_LIST:              "LIST",
		_CONTEXT_IN_OBJECT_FIRST:      "OBJECT-FIRST",
		_CONTEXT_IN_OBJECT_NEXT_KEY:   "OBJECT-NEXT-KEY",
		_CONTEXT_IN_OBJECT_NEXT_VALUE: "OBJECT-NEXT-VALUE",
	}
	if p >= _CONTEXT_IN_TOPLEVEL && int(p) < len(labels) {
		return labels[p]
	}
	return "UNKNOWN-PARSE-CONTEXT"
}
// type TSimpleJSONProtocol struct { trans TTransport parseContextStack []int dumpContext []int writer *bufio.Writer reader *bufio.Reader } // Constructor func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol { v := &TSimpleJSONProtocol{trans: t, writer: bufio.NewWriter(t), reader: bufio.NewReader(t), } v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL)) v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL)) return v } // Factory type TSimpleJSONProtocolFactory struct{} func (p *TSimpleJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol { return NewTSimpleJSONProtocol(trans) } func NewTSimpleJSONProtocolFactory() *TSimpleJSONProtocolFactory { return &TSimpleJSONProtocolFactory{} } var ( JSON_COMMA []byte JSON_COLON []byte JSON_LBRACE []byte JSON_RBRACE []byte JSON_LBRACKET []byte JSON_RBRACKET []byte JSON_QUOTE byte JSON_QUOTE_BYTES []byte JSON_NULL []byte JSON_TRUE []byte JSON_FALSE []byte JSON_INFINITY string JSON_NEGATIVE_INFINITY string JSON_NAN string JSON_INFINITY_BYTES []byte JSON_NEGATIVE_INFINITY_BYTES []byte JSON_NAN_BYTES []byte json_nonbase_map_elem_bytes []byte ) func init() { JSON_COMMA = []byte{','} JSON_COLON = []byte{':'} JSON_LBRACE = []byte{'{'} JSON_RBRACE = []byte{'}'} JSON_LBRACKET = []byte{'['} JSON_RBRACKET = []byte{']'} JSON_QUOTE = '"' JSON_QUOTE_BYTES = []byte{'"'} JSON_NULL = []byte{'n', 'u', 'l', 'l'} JSON_TRUE = []byte{'t', 'r', 'u', 'e'} JSON_FALSE = []byte{'f', 'a', 'l', 's', 'e'} JSON_INFINITY = "Infinity" JSON_NEGATIVE_INFINITY = "-Infinity" JSON_NAN = "NaN" JSON_INFINITY_BYTES = []byte{'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} JSON_NEGATIVE_INFINITY_BYTES = []byte{'-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'} JSON_NAN_BYTES = []byte{'N', 'a', 'N'} json_nonbase_map_elem_bytes = []byte{']', ',', '['} } func jsonQuote(s string) string { b, _ := json.Marshal(s) s1 := string(b) return s1 } func jsonUnquote(s string) (string, bool) { s1 := new(string) err := json.Unmarshal([]byte(s), s1) 
return *s1, err == nil } func mismatch(expected, actual string) error { return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual) } func (p *TSimpleJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { p.resetContextStack() // THRIFT-3735 if e := p.OutputListBegin(); e != nil { return e } if e := p.WriteString(name); e != nil { return e } if e := p.WriteByte(int8(typeId)); e != nil { return e } if e := p.WriteI32(seqId); e != nil { return e } return nil } func (p *TSimpleJSONProtocol) WriteMessageEnd() error { return p.OutputListEnd() } func (p *TSimpleJSONProtocol) WriteStructBegin(name string) error { if e := p.OutputObjectBegin(); e != nil { return e } return nil } func (p *TSimpleJSONProtocol) WriteStructEnd() error { return p.OutputObjectEnd() } func (p *TSimpleJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { if e := p.WriteString(name); e != nil { return e } return nil } func (p *TSimpleJSONProtocol) WriteFieldEnd() error { //return p.OutputListEnd() return nil } func (p *TSimpleJSONProtocol) WriteFieldStop() error { return nil } func (p *TSimpleJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } if e := p.WriteByte(int8(keyType)); e != nil { return e } if e := p.WriteByte(int8(valueType)); e != nil { return e } return p.WriteI32(int32(size)) } func (p *TSimpleJSONProtocol) WriteMapEnd() error { return p.OutputListEnd() } func (p *TSimpleJSONProtocol) WriteListBegin(elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } func (p *TSimpleJSONProtocol) WriteListEnd() error { return p.OutputListEnd() } func (p *TSimpleJSONProtocol) WriteSetBegin(elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } func (p *TSimpleJSONProtocol) WriteSetEnd() error { return p.OutputListEnd() } func (p *TSimpleJSONProtocol) WriteBool(b bool) error { return 
p.OutputBool(b) } func (p *TSimpleJSONProtocol) WriteByte(b int8) error { return p.WriteI32(int32(b)) } func (p *TSimpleJSONProtocol) WriteI16(v int16) error { return p.WriteI32(int32(v)) } func (p *TSimpleJSONProtocol) WriteI32(v int32) error { return p.OutputI64(int64(v)) } func (p *TSimpleJSONProtocol) WriteI64(v int64) error { return p.OutputI64(int64(v)) } func (p *TSimpleJSONProtocol) WriteDouble(v float64) error { return p.OutputF64(v) } func (p *TSimpleJSONProtocol) WriteString(v string) error { return p.OutputString(v) } func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error { // JSON library only takes in a string, // not an arbitrary byte array, to ensure bytes are transmitted // efficiently we must convert this into a valid JSON string // therefore we use base64 encoding to avoid excessive escaping/quoting if e := p.OutputPreValue(); e != nil { return e } if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } writer := base64.NewEncoder(base64.StdEncoding, p.writer) if _, e := writer.Write(v); e != nil { p.writer.Reset(p.trans) // THRIFT-3735 return NewTProtocolException(e) } if e := writer.Close(); e != nil { return NewTProtocolException(e) } if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } return p.OutputPostValue() } // Reading methods. 
func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { p.resetContextStack() // THRIFT-3735 if isNull, err := p.ParseListBegin(); isNull || err != nil { return name, typeId, seqId, err } if name, err = p.ReadString(); err != nil { return name, typeId, seqId, err } bTypeId, err := p.ReadByte() typeId = TMessageType(bTypeId) if err != nil { return name, typeId, seqId, err } if seqId, err = p.ReadI32(); err != nil { return name, typeId, seqId, err } return name, typeId, seqId, nil } func (p *TSimpleJSONProtocol) ReadMessageEnd() error { return p.ParseListEnd() } func (p *TSimpleJSONProtocol) ReadStructBegin() (name string, err error) { _, err = p.ParseObjectStart() return "", err } func (p *TSimpleJSONProtocol) ReadStructEnd() error { return p.ParseObjectEnd() } func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { if err := p.ParsePreValue(); err != nil { return "", STOP, 0, err } b, _ := p.reader.Peek(1) if len(b) > 0 { switch b[0] { case JSON_RBRACE[0]: return "", STOP, 0, nil case JSON_QUOTE: p.reader.ReadByte() name, err := p.ParseStringBody() // simplejson is not meant to be read back into thrift // - see http://wiki.apache.org/thrift/ThriftUsageJava // - use JSON instead if err != nil { return name, STOP, 0, err } return name, STOP, -1, p.ParsePostValue() /* if err = p.ParsePostValue(); err != nil { return name, STOP, 0, err } if isNull, err := p.ParseListBegin(); isNull || err != nil { return name, STOP, 0, err } bType, err := p.ReadByte() thetype := TType(bType) if err != nil { return name, thetype, 0, err } id, err := p.ReadI16() return name, thetype, id, err */ } e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b)) return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e) } return "", STOP, 0, NewTProtocolException(io.EOF) } func (p *TSimpleJSONProtocol) ReadFieldEnd() error { return nil //return p.ParseListEnd() } func (p *TSimpleJSONProtocol) 
ReadMapBegin() (keyType TType, valueType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, VOID, 0, e } // read keyType bKeyType, e := p.ReadByte() keyType = TType(bKeyType) if e != nil { return keyType, valueType, size, e } // read valueType bValueType, e := p.ReadByte() valueType = TType(bValueType) if e != nil { return keyType, valueType, size, e } // read size iSize, err := p.ReadI64() size = int(iSize) return keyType, valueType, size, err } func (p *TSimpleJSONProtocol) ReadMapEnd() error { return p.ParseListEnd() } func (p *TSimpleJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { return p.ParseElemListBegin() } func (p *TSimpleJSONProtocol) ReadListEnd() error { return p.ParseListEnd() } func (p *TSimpleJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { return p.ParseElemListBegin() } func (p *TSimpleJSONProtocol) ReadSetEnd() error { return p.ParseListEnd() } func (p *TSimpleJSONProtocol) ReadBool() (bool, error) { var value bool if err := p.ParsePreValue(); err != nil { return value, err } f, _ := p.reader.Peek(1) if len(f) > 0 { switch f[0] { case JSON_TRUE[0]: b := make([]byte, len(JSON_TRUE)) _, err := p.reader.Read(b) if err != nil { return false, NewTProtocolException(err) } if string(b) == string(JSON_TRUE) { value = true } else { e := fmt.Errorf("Expected \"true\" but found: %s", string(b)) return value, NewTProtocolExceptionWithType(INVALID_DATA, e) } break case JSON_FALSE[0]: b := make([]byte, len(JSON_FALSE)) _, err := p.reader.Read(b) if err != nil { return false, NewTProtocolException(err) } if string(b) == string(JSON_FALSE) { value = false } else { e := fmt.Errorf("Expected \"false\" but found: %s", string(b)) return value, NewTProtocolExceptionWithType(INVALID_DATA, e) } break case JSON_NULL[0]: b := make([]byte, len(JSON_NULL)) _, err := p.reader.Read(b) if err != nil { return false, NewTProtocolException(err) } if string(b) == string(JSON_NULL) { value = false 
} else { e := fmt.Errorf("Expected \"null\" but found: %s", string(b)) return value, NewTProtocolExceptionWithType(INVALID_DATA, e) } default: e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(f)) return value, NewTProtocolExceptionWithType(INVALID_DATA, e) } } return value, p.ParsePostValue() } func (p *TSimpleJSONProtocol) ReadByte() (int8, error) { v, err := p.ReadI64() return int8(v), err } func (p *TSimpleJSONProtocol) ReadI16() (int16, error) { v, err := p.ReadI64() return int16(v), err } func (p *TSimpleJSONProtocol) ReadI32() (int32, error) { v, err := p.ReadI64() return int32(v), err } func (p *TSimpleJSONProtocol) ReadI64() (int64, error) { v, _, err := p.ParseI64() return v, err } func (p *TSimpleJSONProtocol) ReadDouble() (float64, error) { v, _, err := p.ParseF64() return v, err } func (p *TSimpleJSONProtocol) ReadString() (string, error) { var v string if err := p.ParsePreValue(); err != nil { return v, err } f, _ := p.reader.Peek(1) if len(f) > 0 && f[0] == JSON_QUOTE { p.reader.ReadByte() value, err := p.ParseStringBody() v = value if err != nil { return v, err } } else if len(f) > 0 && f[0] == JSON_NULL[0] { b := make([]byte, len(JSON_NULL)) _, err := p.reader.Read(b) if err != nil { return v, NewTProtocolException(err) } if string(b) != string(JSON_NULL) { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } } else { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } return v, p.ParsePostValue() } func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) { var v []byte if err := p.ParsePreValue(); err != nil { return nil, err } f, _ := p.reader.Peek(1) if len(f) > 0 && f[0] == JSON_QUOTE { p.reader.ReadByte() value, err := p.ParseBase64EncodedBody() v = value if err != nil { return v, err } } else if len(f) > 0 && f[0] == 
JSON_NULL[0] { b := make([]byte, len(JSON_NULL)) _, err := p.reader.Read(b) if err != nil { return v, NewTProtocolException(err) } if string(b) != string(JSON_NULL) { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } } else { e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f)) return v, NewTProtocolExceptionWithType(INVALID_DATA, e) } return v, p.ParsePostValue() } func (p *TSimpleJSONProtocol) Flush() (err error) { return NewTProtocolException(p.writer.Flush()) } func (p *TSimpleJSONProtocol) Skip(fieldType TType) (err error) { return SkipDefaultDepth(p, fieldType) } func (p *TSimpleJSONProtocol) Transport() TTransport { return p.trans } func (p *TSimpleJSONProtocol) OutputPreValue() error { cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1]) switch cxt { case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY: if _, e := p.write(JSON_COMMA); e != nil { return NewTProtocolException(e) } break case _CONTEXT_IN_OBJECT_NEXT_VALUE: if _, e := p.write(JSON_COLON); e != nil { return NewTProtocolException(e) } break } return nil } func (p *TSimpleJSONProtocol) OutputPostValue() error { cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1]) switch cxt { case _CONTEXT_IN_LIST_FIRST: p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST)) break case _CONTEXT_IN_OBJECT_FIRST: p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) break case _CONTEXT_IN_OBJECT_NEXT_KEY: p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) break case _CONTEXT_IN_OBJECT_NEXT_VALUE: p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_KEY)) break } return nil } func (p *TSimpleJSONProtocol) 
OutputBool(value bool) error { if e := p.OutputPreValue(); e != nil { return e } var v string if value { v = string(JSON_TRUE) } else { v = string(JSON_FALSE) } switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: v = jsonQuote(v) default: } if e := p.OutputStringData(v); e != nil { return e } return p.OutputPostValue() } func (p *TSimpleJSONProtocol) OutputNull() error { if e := p.OutputPreValue(); e != nil { return e } if _, e := p.write(JSON_NULL); e != nil { return NewTProtocolException(e) } return p.OutputPostValue() } func (p *TSimpleJSONProtocol) OutputF64(value float64) error { if e := p.OutputPreValue(); e != nil { return e } var v string if math.IsNaN(value) { v = string(JSON_QUOTE) + JSON_NAN + string(JSON_QUOTE) } else if math.IsInf(value, 1) { v = string(JSON_QUOTE) + JSON_INFINITY + string(JSON_QUOTE) } else if math.IsInf(value, -1) { v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE) } else { v = strconv.FormatFloat(value, 'g', -1, 64) switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: v = string(JSON_QUOTE) + v + string(JSON_QUOTE) default: } } if e := p.OutputStringData(v); e != nil { return e } return p.OutputPostValue() } func (p *TSimpleJSONProtocol) OutputI64(value int64) error { if e := p.OutputPreValue(); e != nil { return e } v := strconv.FormatInt(value, 10) switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) { case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: v = jsonQuote(v) default: } if e := p.OutputStringData(v); e != nil { return e } return p.OutputPostValue() } func (p *TSimpleJSONProtocol) OutputString(s string) error { if e := p.OutputPreValue(); e != nil { return e } if e := p.OutputStringData(jsonQuote(s)); e != nil { return e } return p.OutputPostValue() } func (p *TSimpleJSONProtocol) OutputStringData(s string) error { _, e := p.write([]byte(s)) return 
NewTProtocolException(e) } func (p *TSimpleJSONProtocol) OutputObjectBegin() error { if e := p.OutputPreValue(); e != nil { return e } if _, e := p.write(JSON_LBRACE); e != nil { return NewTProtocolException(e) } p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST)) return nil } func (p *TSimpleJSONProtocol) OutputObjectEnd() error { if _, e := p.write(JSON_RBRACE); e != nil { return NewTProtocolException(e) } p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] if e := p.OutputPostValue(); e != nil { return e } return nil } func (p *TSimpleJSONProtocol) OutputListBegin() error { if e := p.OutputPreValue(); e != nil { return e } if _, e := p.write(JSON_LBRACKET); e != nil { return NewTProtocolException(e) } p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST)) return nil } func (p *TSimpleJSONProtocol) OutputListEnd() error { if _, e := p.write(JSON_RBRACKET); e != nil { return NewTProtocolException(e) } p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] if e := p.OutputPostValue(); e != nil { return e } return nil } func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } if e := p.WriteByte(int8(elemType)); e != nil { return e } if e := p.WriteI64(int64(size)); e != nil { return e } return nil } func (p *TSimpleJSONProtocol) ParsePreValue() error { if e := p.readNonSignificantWhitespace(); e != nil { return NewTProtocolException(e) } cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) b, _ := p.reader.Peek(1) switch cxt { case _CONTEXT_IN_LIST: if len(b) > 0 { switch b[0] { case JSON_RBRACKET[0]: return nil case JSON_COMMA[0]: p.reader.ReadByte() if e := p.readNonSignificantWhitespace(); e != nil { return NewTProtocolException(e) } return nil default: e := fmt.Errorf("Expected \"]\" or \",\" in list context, but found \"%s\"", string(b)) return NewTProtocolExceptionWithType(INVALID_DATA, e) } } break case _CONTEXT_IN_OBJECT_NEXT_KEY: if 
len(b) > 0 { switch b[0] { case JSON_RBRACE[0]: return nil case JSON_COMMA[0]: p.reader.ReadByte() if e := p.readNonSignificantWhitespace(); e != nil { return NewTProtocolException(e) } return nil default: e := fmt.Errorf("Expected \"}\" or \",\" in object context, but found \"%s\"", string(b)) return NewTProtocolExceptionWithType(INVALID_DATA, e) } } break case _CONTEXT_IN_OBJECT_NEXT_VALUE: if len(b) > 0 { switch b[0] { case JSON_COLON[0]: p.reader.ReadByte() if e := p.readNonSignificantWhitespace(); e != nil { return NewTProtocolException(e) } return nil default: e := fmt.Errorf("Expected \":\" in object context, but found \"%s\"", string(b)) return NewTProtocolExceptionWithType(INVALID_DATA, e) } } break } return nil } func (p *TSimpleJSONProtocol) ParsePostValue() error { if e := p.readNonSignificantWhitespace(); e != nil { return NewTProtocolException(e) } cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) switch cxt { case _CONTEXT_IN_LIST_FIRST: p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST)) break case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY: p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_VALUE)) break case _CONTEXT_IN_OBJECT_NEXT_VALUE: p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_KEY)) break } return nil } func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error { for { b, _ := p.reader.Peek(1) if len(b) < 1 { return nil } switch b[0] { case ' ', '\r', '\n', '\t': p.reader.ReadByte() continue default: break } break } return nil } func (p *TSimpleJSONProtocol) ParseStringBody() (string, error) { line, err := p.reader.ReadString(JSON_QUOTE) if err != nil { return "", NewTProtocolException(err) } l := len(line) // 
count number of escapes to see if we need to keep going i := 1 for ; i < l; i++ { if line[l-i-1] != '\\' { break } } if i&0x01 == 1 { v, ok := jsonUnquote(string(JSON_QUOTE) + line) if !ok { return "", NewTProtocolException(err) } return v, nil } s, err := p.ParseQuotedStringBody() if err != nil { return "", NewTProtocolException(err) } str := string(JSON_QUOTE) + line + s v, ok := jsonUnquote(str) if !ok { e := fmt.Errorf("Unable to parse as JSON string %s", str) return "", NewTProtocolExceptionWithType(INVALID_DATA, e) } return v, nil } func (p *TSimpleJSONProtocol) ParseQuotedStringBody() (string, error) { line, err := p.reader.ReadString(JSON_QUOTE) if err != nil { return "", NewTProtocolException(err) } l := len(line) // count number of escapes to see if we need to keep going i := 1 for ; i < l; i++ { if line[l-i-1] != '\\' { break } } if i&0x01 == 1 { return line, nil } s, err := p.ParseQuotedStringBody() if err != nil { return "", NewTProtocolException(err) } v := line + s return v, nil } func (p *TSimpleJSONProtocol) ParseBase64EncodedBody() ([]byte, error) { line, err := p.reader.ReadBytes(JSON_QUOTE) if err != nil { return line, NewTProtocolException(err) } line2 := line[0 : len(line)-1] l := len(line2) if (l % 4) != 0 { pad := 4 - (l % 4) fill := [...]byte{'=', '=', '='} line2 = append(line2, fill[:pad]...) 
l = len(line2) } output := make([]byte, base64.StdEncoding.DecodedLen(l)) n, err := base64.StdEncoding.Decode(output, line2) return output[0:n], NewTProtocolException(err) } func (p *TSimpleJSONProtocol) ParseI64() (int64, bool, error) { if err := p.ParsePreValue(); err != nil { return 0, false, err } var value int64 var isnull bool if p.safePeekContains(JSON_NULL) { p.reader.Read(make([]byte, len(JSON_NULL))) isnull = true } else { num, err := p.readNumeric() isnull = (num == nil) if !isnull { value = num.Int64() } if err != nil { return value, isnull, err } } return value, isnull, p.ParsePostValue() } func (p *TSimpleJSONProtocol) ParseF64() (float64, bool, error) { if err := p.ParsePreValue(); err != nil { return 0, false, err } var value float64 var isnull bool if p.safePeekContains(JSON_NULL) { p.reader.Read(make([]byte, len(JSON_NULL))) isnull = true } else { num, err := p.readNumeric() isnull = (num == nil) if !isnull { value = num.Float64() } if err != nil { return value, isnull, err } } return value, isnull, p.ParsePostValue() } func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) { if err := p.ParsePreValue(); err != nil { return false, err } var b []byte b, err := p.reader.Peek(1) if err != nil { return false, err } if len(b) > 0 && b[0] == JSON_LBRACE[0] { p.reader.ReadByte() p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST)) return false, nil } else if p.safePeekContains(JSON_NULL) { return true, nil } e := fmt.Errorf("Expected '{' or null, but found '%s'", string(b)) return false, NewTProtocolExceptionWithType(INVALID_DATA, e) } func (p *TSimpleJSONProtocol) ParseObjectEnd() error { if isNull, err := p.readIfNull(); isNull || err != nil { return err } cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) { e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt) return 
NewTProtocolExceptionWithType(INVALID_DATA, e) } line, err := p.reader.ReadString(JSON_RBRACE[0]) if err != nil { return NewTProtocolException(err) } for _, char := range line { switch char { default: e := fmt.Errorf("Expecting end of object \"}\", but found: \"%s\"", line) return NewTProtocolExceptionWithType(INVALID_DATA, e) case ' ', '\n', '\r', '\t', '}': break } } p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] return p.ParsePostValue() } func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) { if e := p.ParsePreValue(); e != nil { return false, e } var b []byte b, err = p.reader.Peek(1) if err != nil { return false, err } if len(b) >= 1 && b[0] == JSON_LBRACKET[0] { p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST)) p.reader.ReadByte() isNull = false } else if p.safePeekContains(JSON_NULL) { isNull = true } else { err = fmt.Errorf("Expected \"null\" or \"[\", received %q", b) } return isNull, NewTProtocolExceptionWithType(INVALID_DATA, err) } func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, 0, e } bElemType, err := p.ReadByte() elemType = TType(bElemType) if err != nil { return elemType, size, err } nSize, err2 := p.ReadI64() size = int(nSize) return elemType, size, err2 } func (p *TSimpleJSONProtocol) ParseListEnd() error { if isNull, err := p.readIfNull(); isNull || err != nil { return err } cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) if cxt != _CONTEXT_IN_LIST { e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt) return NewTProtocolExceptionWithType(INVALID_DATA, e) } line, err := p.reader.ReadString(JSON_RBRACKET[0]) if err != nil { return NewTProtocolException(err) } for _, char := range line { switch char { default: e := fmt.Errorf("Expecting end of list \"]\", but found: \"", line, "\"") return 
NewTProtocolExceptionWithType(INVALID_DATA, e) case ' ', '\n', '\r', '\t', rune(JSON_RBRACKET[0]): break } } p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1] if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL { return nil } return p.ParsePostValue() } func (p *TSimpleJSONProtocol) readSingleValue() (interface{}, TType, error) { e := p.readNonSignificantWhitespace() if e != nil { return nil, VOID, NewTProtocolException(e) } b, e := p.reader.Peek(1) if len(b) > 0 { c := b[0] switch c { case JSON_NULL[0]: buf := make([]byte, len(JSON_NULL)) _, e := p.reader.Read(buf) if e != nil { return nil, VOID, NewTProtocolException(e) } if string(JSON_NULL) != string(buf) { e = mismatch(string(JSON_NULL), string(buf)) return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) } return nil, VOID, nil case JSON_QUOTE: p.reader.ReadByte() v, e := p.ParseStringBody() if e != nil { return v, UTF8, NewTProtocolException(e) } if v == JSON_INFINITY { return INFINITY, DOUBLE, nil } else if v == JSON_NEGATIVE_INFINITY { return NEGATIVE_INFINITY, DOUBLE, nil } else if v == JSON_NAN { return NAN, DOUBLE, nil } return v, UTF8, nil case JSON_TRUE[0]: buf := make([]byte, len(JSON_TRUE)) _, e := p.reader.Read(buf) if e != nil { return true, BOOL, NewTProtocolException(e) } if string(JSON_TRUE) != string(buf) { e := mismatch(string(JSON_TRUE), string(buf)) return true, BOOL, NewTProtocolExceptionWithType(INVALID_DATA, e) } return true, BOOL, nil case JSON_FALSE[0]: buf := make([]byte, len(JSON_FALSE)) _, e := p.reader.Read(buf) if e != nil { return false, BOOL, NewTProtocolException(e) } if string(JSON_FALSE) != string(buf) { e := mismatch(string(JSON_FALSE), string(buf)) return false, BOOL, NewTProtocolExceptionWithType(INVALID_DATA, e) } return false, BOOL, nil case JSON_LBRACKET[0]: _, e := p.reader.ReadByte() return make([]interface{}, 0), LIST, NewTProtocolException(e) case JSON_LBRACE[0]: _, e := p.reader.ReadByte() return 
make(map[string]interface{}), STRUCT, NewTProtocolException(e) case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'e', 'E', '.', '+', '-', JSON_INFINITY[0], JSON_NAN[0]: // assume numeric v, e := p.readNumeric() return v, DOUBLE, e default: e := fmt.Errorf("Expected element in list but found '%s' while parsing JSON.", string(c)) return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) } } e = fmt.Errorf("Cannot read a single element while parsing JSON.") return nil, VOID, NewTProtocolExceptionWithType(INVALID_DATA, e) } func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { cont := true for cont { b, _ := p.reader.Peek(1) if len(b) < 1 { return false, nil } switch b[0] { default: return false, nil case JSON_NULL[0]: cont = false break case ' ', '\n', '\r', '\t': p.reader.ReadByte() break } } if p.safePeekContains(JSON_NULL) { p.reader.Read(make([]byte, len(JSON_NULL))) return true, nil } return false, nil } func (p *TSimpleJSONProtocol) readQuoteIfNext() { b, _ := p.reader.Peek(1) if len(b) > 0 && b[0] == JSON_QUOTE { p.reader.ReadByte() } } func (p *TSimpleJSONProtocol) readNumeric() (Numeric, error) { isNull, err := p.readIfNull() if isNull || err != nil { return NUMERIC_NULL, err } hasDecimalPoint := false nextCanBeSign := true hasE := false MAX_LEN := 40 buf := bytes.NewBuffer(make([]byte, 0, MAX_LEN)) continueFor := true inQuotes := false for continueFor { c, err := p.reader.ReadByte() if err != nil { if err == io.EOF { break } return NUMERIC_NULL, NewTProtocolException(err) } switch c { case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': buf.WriteByte(c) nextCanBeSign = false case '.': if hasDecimalPoint { e := fmt.Errorf("Unable to parse number with multiple decimal points '%s.'", buf.String()) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } if hasE { e := fmt.Errorf("Unable to parse number with decimal points in the exponent '%s.'", buf.String()) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } 
buf.WriteByte(c) hasDecimalPoint, nextCanBeSign = true, false case 'e', 'E': if hasE { e := fmt.Errorf("Unable to parse number with multiple exponents '%s%c'", buf.String(), c) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } buf.WriteByte(c) hasE, nextCanBeSign = true, true case '-', '+': if !nextCanBeSign { e := fmt.Errorf("Negative sign within number") return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } buf.WriteByte(c) nextCanBeSign = false case ' ', 0, '\t', '\n', '\r', JSON_RBRACE[0], JSON_RBRACKET[0], JSON_COMMA[0], JSON_COLON[0]: p.reader.UnreadByte() continueFor = false case JSON_NAN[0]: if buf.Len() == 0 { buffer := make([]byte, len(JSON_NAN)) buffer[0] = c _, e := p.reader.Read(buffer[1:]) if e != nil { return NUMERIC_NULL, NewTProtocolException(e) } if JSON_NAN != string(buffer) { e := mismatch(JSON_NAN, string(buffer)) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } if inQuotes { p.readQuoteIfNext() } return NAN, nil } else { e := fmt.Errorf("Unable to parse number starting with character '%c'", c) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } case JSON_INFINITY[0]: if buf.Len() == 0 || (buf.Len() == 1 && buf.Bytes()[0] == '+') { buffer := make([]byte, len(JSON_INFINITY)) buffer[0] = c _, e := p.reader.Read(buffer[1:]) if e != nil { return NUMERIC_NULL, NewTProtocolException(e) } if JSON_INFINITY != string(buffer) { e := mismatch(JSON_INFINITY, string(buffer)) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } if inQuotes { p.readQuoteIfNext() } return INFINITY, nil } else if buf.Len() == 1 && buf.Bytes()[0] == JSON_NEGATIVE_INFINITY[0] { buffer := make([]byte, len(JSON_NEGATIVE_INFINITY)) buffer[0] = JSON_NEGATIVE_INFINITY[0] buffer[1] = c _, e := p.reader.Read(buffer[2:]) if e != nil { return NUMERIC_NULL, NewTProtocolException(e) } if JSON_NEGATIVE_INFINITY != string(buffer) { e := mismatch(JSON_NEGATIVE_INFINITY, string(buffer)) return 
NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } if inQuotes { p.readQuoteIfNext() } return NEGATIVE_INFINITY, nil } else { e := fmt.Errorf("Unable to parse number starting with character '%c' due to existing buffer %s", c, buf.String()) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } case JSON_QUOTE: if !inQuotes { inQuotes = true } else { break } default: e := fmt.Errorf("Unable to parse number starting with character '%c'", c) return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } } if buf.Len() == 0 { e := fmt.Errorf("Unable to parse number from empty string ''") return NUMERIC_NULL, NewTProtocolExceptionWithType(INVALID_DATA, e) } return NewNumericFromJSONString(buf.String(), false), nil } // Safely peeks into the buffer, reading only what is necessary func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool { for i := 0; i < len(b); i++ { a, _ := p.reader.Peek(i + 1) if len(a) == 0 || a[i] != b[i] { return false } } return true } // Reset the context stack to its initial state. func (p *TSimpleJSONProtocol) resetContextStack() { p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)} p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)} } func (p *TSimpleJSONProtocol) write(b []byte) (int, error) { n, err := p.writer.Write(b) if err != nil { p.writer.Reset(p.trans) // THRIFT-3735 } return n, err } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/simple_json_protocol_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "encoding/base64" "encoding/json" "fmt" "math" "strconv" "strings" "testing" ) func TestWriteSimpleJSONProtocolBool(t *testing.T) { thetype := "boolean" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range BOOL_VALUES { if e := p.WriteBool(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := false if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadSimpleJSONProtocolBool(t *testing.T) { thetype := "boolean" for _, value := range BOOL_VALUES { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) if value { trans.Write(JSON_TRUE) } else { trans.Write(JSON_FALSE) } trans.Flush() s := trans.String() v, e := p.ReadBool() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestWriteSimpleJSONProtocolByte(t *testing.T) { 
thetype := "byte" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range BYTE_VALUES { if e := p.WriteByte(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int8(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadSimpleJSONProtocolByte(t *testing.T) { thetype := "byte" for _, value := range BYTE_VALUES { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(strconv.Itoa(int(value))) trans.Flush() s := trans.String() v, e := p.ReadByte() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestWriteSimpleJSONProtocolI16(t *testing.T) { thetype := "int16" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range INT16_VALUES { if e := p.WriteI16(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int16(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s 
%v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadSimpleJSONProtocolI16(t *testing.T) { thetype := "int16" for _, value := range INT16_VALUES { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(strconv.Itoa(int(value))) trans.Flush() s := trans.String() v, e := p.ReadI16() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestWriteSimpleJSONProtocolI32(t *testing.T) { thetype := "int32" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range INT32_VALUES { if e := p.WriteI32(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int32(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadSimpleJSONProtocolI32(t *testing.T) { thetype := "int32" for _, value := range INT32_VALUES { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(strconv.Itoa(int(value))) trans.Flush() s := trans.String() v, e := p.ReadI32() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := 
json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestReadSimpleJSONProtocolI32Null(t *testing.T) { thetype := "int32" value := "null" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(value) trans.Flush() s := trans.String() v, e := p.ReadI32() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != 0 { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } trans.Reset() trans.Close() } func TestWriteSimpleJSONProtocolI64(t *testing.T) { thetype := "int64" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range INT64_VALUES { if e := p.WriteI64(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := int64(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } trans.Reset() } trans.Close() } func TestReadSimpleJSONProtocolI64(t *testing.T) { thetype := "int64" for _, value := range INT64_VALUES { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(strconv.FormatInt(value, 10)) trans.Flush() s := trans.String() v, e := p.ReadI64() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: 
'%v'", thetype, value, s, v) } trans.Reset() trans.Close() } } func TestReadSimpleJSONProtocolI64Null(t *testing.T) { thetype := "int32" value := "null" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(value) trans.Flush() s := trans.String() v, e := p.ReadI64() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != 0 { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } trans.Reset() trans.Close() } func TestWriteSimpleJSONProtocolDouble(t *testing.T) { thetype := "double" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range DOUBLE_VALUES { if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if math.IsInf(value, 1) { if s != jsonQuote(JSON_INFINITY) { t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY)) } } else if math.IsInf(value, -1) { if s != jsonQuote(JSON_NEGATIVE_INFINITY) { t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) } } else if math.IsNaN(value) { if s != jsonQuote(JSON_NAN) { t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN)) } } else { if s != fmt.Sprint(value) { t.Fatalf("Bad value for %s %v: %s", thetype, value, s) } v := float64(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } } trans.Reset() } trans.Close() } func TestReadSimpleJSONProtocolDouble(t *testing.T) { thetype := "double" for _, value := range DOUBLE_VALUES { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) n := 
NewNumericFromDouble(value) trans.WriteString(n.String()) trans.Flush() s := trans.String() v, e := p.ReadDouble() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if math.IsInf(value, 1) { if !math.IsInf(v, 1) { t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) } } else if math.IsInf(value, -1) { if !math.IsInf(v, -1) { t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) } } else if math.IsNaN(value) { if !math.IsNaN(v) { t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v) } } else { if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } } trans.Reset() trans.Close() } } func TestWriteSimpleJSONProtocolString(t *testing.T) { thetype := "string" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range STRING_VALUES { if e := p.WriteString(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s[0] != '"' || s[len(s)-1] != '"' { t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\"")) } v := new(string) if err := json.Unmarshal([]byte(s), v); err != nil || *v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v) } trans.Reset() } trans.Close() } func TestReadSimpleJSONProtocolString(t *testing.T) { thetype := "string" for _, value := range STRING_VALUES { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(jsonQuote(value)) trans.Flush() s := trans.String() v, e := 
p.ReadString() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != value { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } v1 := new(string) if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) } trans.Reset() trans.Close() } } func TestReadSimpleJSONProtocolStringNull(t *testing.T) { thetype := "string" value := "null" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(value) trans.Flush() s := trans.String() v, e := p.ReadString() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != "" { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } trans.Reset() trans.Close() } func TestWriteSimpleJSONProtocolBinary(t *testing.T) { thetype := "binary" value := protocol_bdata b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata))) base64.StdEncoding.Encode(b64value, value) b64String := string(b64value) trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) if e := p.WriteBinary(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error()) } s := trans.String() if s != fmt.Sprint("\"", b64String, "\"") { t.Fatalf("Bad value for %s %v\n wrote: %v\nexpected: %v", thetype, value, s, "\""+b64String+"\"") } v1 := new(string) if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) } trans.Close() } func TestReadSimpleJSONProtocolBinary(t *testing.T) { thetype := "binary" value := protocol_bdata b64value := make([]byte, 
base64.StdEncoding.EncodedLen(len(protocol_bdata))) base64.StdEncoding.Encode(b64value, value) b64String := string(b64value) trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(jsonQuote(b64String)) trans.Flush() s := trans.String() v, e := p.ReadBinary() if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if len(v) != len(value) { t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v)) } for i := 0; i < len(v); i++ { if v[i] != value[i] { t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i]) } } v1 := new(string) if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1) } trans.Reset() trans.Close() } func TestReadSimpleJSONProtocolBinaryNull(t *testing.T) { thetype := "binary" value := "null" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) trans.WriteString(value) trans.Flush() s := trans.String() b, e := p.ReadBinary() v := string(b) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } if v != "" { t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v) } trans.Reset() trans.Close() } func TestWriteSimpleJSONProtocolList(t *testing.T) { thetype := "list" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } p.WriteListEnd() if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } str := trans.String() str1 := new([]interface{}) err := json.Unmarshal([]byte(str), str1) if err != nil { t.Fatalf("Unable to decode 
%s, wrote: %s", thetype, str) } l := *str1 if len(l) < 2 { t.Fatalf("List must be at least of length two to include metadata") } if int(l[0].(float64)) != DOUBLE { t.Fatal("Invalid type for list, expected: ", DOUBLE, ", but was: ", l[0]) } if int(l[1].(float64)) != len(DOUBLE_VALUES) { t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) } for k, value := range DOUBLE_VALUES { s := l[k+2] if math.IsInf(value, 1) { if s.(string) != JSON_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) } } else if math.IsInf(value, 0) { if s.(string) != JSON_NEGATIVE_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) } } else if math.IsNaN(value) { if s.(string) != JSON_NAN { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) } } else { if s.(float64) != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) } } trans.Reset() } trans.Close() } func TestWriteSimpleJSONProtocolSet(t *testing.T) { thetype := "set" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } p.WriteSetEnd() if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } str := trans.String() str1 := new([]interface{}) err := json.Unmarshal([]byte(str), str1) if err != nil { t.Fatalf("Unable to decode %s, wrote: %s", thetype, str) } l := *str1 if len(l) < 2 { t.Fatalf("Set must be at least of length two to include metadata") } if int(l[0].(float64)) != DOUBLE { t.Fatal("Invalid 
type for set, expected: ", DOUBLE, ", but was: ", l[0]) } if int(l[1].(float64)) != len(DOUBLE_VALUES) { t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1]) } for k, value := range DOUBLE_VALUES { s := l[k+2] if math.IsInf(value, 1) { if s.(string) != JSON_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str) } } else if math.IsInf(value, 0) { if s.(string) != JSON_NEGATIVE_INFINITY { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str) } } else if math.IsNaN(value) { if s.(string) != JSON_NAN { t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str) } } else { if s.(float64) != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s) } } trans.Reset() } trans.Close() } func TestWriteSimpleJSONProtocolMap(t *testing.T) { thetype := "map" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) for k, value := range DOUBLE_VALUES { if e := p.WriteI32(int32(k)); e != nil { t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) } if e := p.WriteDouble(value); e != nil { t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) } } p.WriteMapEnd() if e := p.Flush(); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } str := trans.String() if str[0] != '[' || str[len(str)-1] != ']' { t.Fatalf("Bad value for %s, wrote: %q, in go: %q", thetype, str, DOUBLE_VALUES) } l := strings.Split(str[1:len(str)-1], ",") if len(l) < 3 { t.Fatal("Expected list of at least length 3 for map for metadata, but was of length ", len(l)) } expectedKeyType, _ := 
strconv.Atoi(l[0]) expectedValueType, _ := strconv.Atoi(l[1]) expectedSize, _ := strconv.Atoi(l[2]) if expectedKeyType != I32 { t.Fatal("Expected map key type ", I32, ", but was ", l[0]) } if expectedValueType != DOUBLE { t.Fatal("Expected map value type ", DOUBLE, ", but was ", l[1]) } if expectedSize != len(DOUBLE_VALUES) { t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", l[2]) } for k, value := range DOUBLE_VALUES { strk := l[k*2+3] strv := l[k*2+4] ik, err := strconv.Atoi(strk) if err != nil { t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, strk, string(k), err.Error()) } if ik != k { t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v", thetype, k, strk, k) } s := strv if math.IsInf(value, 1) { if s != jsonQuote(JSON_INFINITY) { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY)) } } else if math.IsInf(value, 0) { if s != jsonQuote(JSON_NEGATIVE_INFINITY) { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY)) } } else if math.IsNaN(value) { if s != jsonQuote(JSON_NAN) { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN)) } } else { expected := strconv.FormatFloat(value, 'g', 10, 64) if s != expected { t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected) } v := float64(0) if err := json.Unmarshal([]byte(s), &v); err != nil || v != value { t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v) } } trans.Reset() } trans.Close() } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/simple_server.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "log" "runtime/debug" "sync" ) // Simple, non-concurrent server for testing. type TSimpleServer struct { quit chan struct{} processorFactory TProcessorFactory serverTransport TServerTransport inputTransportFactory TTransportFactory outputTransportFactory TTransportFactory inputProtocolFactory TProtocolFactory outputProtocolFactory TProtocolFactory } func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer { return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport) } func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { return NewTSimpleServerFactory4(NewTProcessorFactory(processor), serverTransport, transportFactory, protocolFactory, ) } func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { return NewTSimpleServerFactory6(NewTProcessorFactory(processor), serverTransport, inputTransportFactory, outputTransportFactory, inputProtocolFactory, outputProtocolFactory, ) } func 
NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer { return NewTSimpleServerFactory6(processorFactory, serverTransport, NewTTransportFactory(), NewTTransportFactory(), NewTBinaryProtocolFactoryDefault(), NewTBinaryProtocolFactoryDefault(), ) } func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer { return NewTSimpleServerFactory6(processorFactory, serverTransport, transportFactory, transportFactory, protocolFactory, protocolFactory, ) } func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer { return &TSimpleServer{ processorFactory: processorFactory, serverTransport: serverTransport, inputTransportFactory: inputTransportFactory, outputTransportFactory: outputTransportFactory, inputProtocolFactory: inputProtocolFactory, outputProtocolFactory: outputProtocolFactory, quit: make(chan struct{}, 1), } } func (p *TSimpleServer) ProcessorFactory() TProcessorFactory { return p.processorFactory } func (p *TSimpleServer) ServerTransport() TServerTransport { return p.serverTransport } func (p *TSimpleServer) InputTransportFactory() TTransportFactory { return p.inputTransportFactory } func (p *TSimpleServer) OutputTransportFactory() TTransportFactory { return p.outputTransportFactory } func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory { return p.inputProtocolFactory } func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory { return p.outputProtocolFactory } func (p *TSimpleServer) Listen() error { return p.serverTransport.Listen() } func (p *TSimpleServer) AcceptLoop() error { for { client, err := p.serverTransport.Accept() if err != nil { select { 
case <-p.quit: return nil default: } return err } if client != nil { go func() { if err := p.processRequests(client); err != nil { log.Println("error processing request:", err) } }() } } } func (p *TSimpleServer) Serve() error { err := p.Listen() if err != nil { return err } p.AcceptLoop() return nil } var once sync.Once func (p *TSimpleServer) Stop() error { q := func() { p.quit <- struct{}{} p.serverTransport.Interrupt() } once.Do(q) return nil } func (p *TSimpleServer) processRequests(client TTransport) error { processor := p.processorFactory.GetProcessor(client) inputTransport := p.inputTransportFactory.GetTransport(client) outputTransport := p.outputTransportFactory.GetTransport(client) inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport) outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport) defer func() { if e := recover(); e != nil { log.Printf("panic in processor: %s: %s", e, debug.Stack()) } }() if inputTransport != nil { defer inputTransport.Close() } if outputTransport != nil { defer outputTransport.Close() } for { ok, err := processor.Process(inputProtocol, outputProtocol) if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE { return nil } else if err != nil { log.Printf("error processing request: %s", err) return err } if err, ok := err.(TApplicationException); ok && err.TypeId() == UNKNOWN_METHOD { continue } if !ok { break } } return nil } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/socket.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "net" "time" ) type TSocket struct { conn net.Conn addr net.Addr timeout time.Duration } // NewTSocket creates a net.Conn-backed TTransport, given a host and port // // Example: // trans, err := thrift.NewTSocket("localhost:9090") func NewTSocket(hostPort string) (*TSocket, error) { return NewTSocketTimeout(hostPort, 0) } // NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port // it also accepts a timeout as a time.Duration func NewTSocketTimeout(hostPort string, timeout time.Duration) (*TSocket, error) { //conn, err := net.DialTimeout(network, address, timeout) addr, err := net.ResolveTCPAddr("tcp", hostPort) if err != nil { return nil, err } return NewTSocketFromAddrTimeout(addr, timeout), nil } // Creates a TSocket from a net.Addr func NewTSocketFromAddrTimeout(addr net.Addr, timeout time.Duration) *TSocket { return &TSocket{addr: addr, timeout: timeout} } // Creates a TSocket from an existing net.Conn func NewTSocketFromConnTimeout(conn net.Conn, timeout time.Duration) *TSocket { return &TSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout} } // Sets the socket timeout func (p *TSocket) SetTimeout(timeout time.Duration) error { p.timeout = timeout return nil } func (p *TSocket) pushDeadline(read, write bool) { var t time.Time if p.timeout > 0 { t = time.Now().Add(time.Duration(p.timeout)) } if read && write { p.conn.SetDeadline(t) } else if read { p.conn.SetReadDeadline(t) } else if write { p.conn.SetWriteDeadline(t) } } // Connects the socket, creating a new socket object if necessary. 
func (p *TSocket) Open() error { if p.IsOpen() { return NewTTransportException(ALREADY_OPEN, "Socket already connected.") } if p.addr == nil { return NewTTransportException(NOT_OPEN, "Cannot open nil address.") } if len(p.addr.Network()) == 0 { return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") } if len(p.addr.String()) == 0 { return NewTTransportException(NOT_OPEN, "Cannot open bad address.") } var err error if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.timeout); err != nil { return NewTTransportException(NOT_OPEN, err.Error()) } return nil } // Retrieve the underlying net.Conn func (p *TSocket) Conn() net.Conn { return p.conn } // Returns true if the connection is open func (p *TSocket) IsOpen() bool { if p.conn == nil { return false } return true } // Closes the socket. func (p *TSocket) Close() error { // Close the socket if p.conn != nil { err := p.conn.Close() if err != nil { return err } p.conn = nil } return nil } //Returns the remote address of the socket. 
func (p *TSocket) Addr() net.Addr { return p.addr } func (p *TSocket) Read(buf []byte) (int, error) { if !p.IsOpen() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(true, false) n, err := p.conn.Read(buf) return n, NewTTransportExceptionFromError(err) } func (p *TSocket) Write(buf []byte) (int, error) { if !p.IsOpen() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(false, true) return p.conn.Write(buf) } func (p *TSocket) Flush() error { return nil } func (p *TSocket) Interrupt() error { if !p.IsOpen() { return nil } return p.conn.Close() } func (p *TSocket) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) return maxSize // the thruth is, we just don't know unless framed is used } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/ssl_server_socket.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "net" "time" "crypto/tls" ) type TSSLServerSocket struct { listener net.Listener addr net.Addr clientTimeout time.Duration interrupted bool cfg *tls.Config } func NewTSSLServerSocket(listenAddr string, cfg *tls.Config) (*TSSLServerSocket, error) { return NewTSSLServerSocketTimeout(listenAddr, cfg, 0) } func NewTSSLServerSocketTimeout(listenAddr string, cfg *tls.Config, clientTimeout time.Duration) (*TSSLServerSocket, error) { addr, err := net.ResolveTCPAddr("tcp", listenAddr) if err != nil { return nil, err } return &TSSLServerSocket{addr: addr, clientTimeout: clientTimeout, cfg: cfg}, nil } func (p *TSSLServerSocket) Listen() error { if p.IsListening() { return nil } l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg) if err != nil { return err } p.listener = l return nil } func (p *TSSLServerSocket) Accept() (TTransport, error) { if p.interrupted { return nil, errTransportInterrupted } if p.listener == nil { return nil, NewTTransportException(NOT_OPEN, "No underlying server socket") } conn, err := p.listener.Accept() if err != nil { return nil, NewTTransportExceptionFromError(err) } return NewTSSLSocketFromConnTimeout(conn, p.cfg, p.clientTimeout), nil } // Checks whether the socket is listening. func (p *TSSLServerSocket) IsListening() bool { return p.listener != nil } // Connects the socket, creating a new socket object if necessary. 
func (p *TSSLServerSocket) Open() error { if p.IsListening() { return NewTTransportException(ALREADY_OPEN, "Server socket already open") } if l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg); err != nil { return err } else { p.listener = l } return nil } func (p *TSSLServerSocket) Addr() net.Addr { return p.addr } func (p *TSSLServerSocket) Close() error { defer func() { p.listener = nil }() if p.IsListening() { return p.listener.Close() } return nil } func (p *TSSLServerSocket) Interrupt() error { p.interrupted = true return nil } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/ssl_socket.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "crypto/tls" "net" "time" ) type TSSLSocket struct { conn net.Conn // hostPort contains host:port (e.g. "asdf.com:12345"). The field is // only valid if addr is nil. hostPort string // addr is nil when hostPort is not "", and is only used when the // TSSLSocket is constructed from a net.Addr. 
addr net.Addr timeout time.Duration cfg *tls.Config } // NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration // // Example: // trans, err := thrift.NewTSSLSocket("localhost:9090", nil) func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) { return NewTSSLSocketTimeout(hostPort, cfg, 0) } // NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port // it also accepts a tls Configuration and a timeout as a time.Duration func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) { return &TSSLSocket{hostPort: hostPort, timeout: timeout, cfg: cfg}, nil } // Creates a TSSLSocket from a net.Addr func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket { return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg} } // Creates a TSSLSocket from an existing net.Conn func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket { return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg} } // Sets the socket timeout func (p *TSSLSocket) SetTimeout(timeout time.Duration) error { p.timeout = timeout return nil } func (p *TSSLSocket) pushDeadline(read, write bool) { var t time.Time if p.timeout > 0 { t = time.Now().Add(time.Duration(p.timeout)) } if read && write { p.conn.SetDeadline(t) } else if read { p.conn.SetReadDeadline(t) } else if write { p.conn.SetWriteDeadline(t) } } // Connects the socket, creating a new socket object if necessary. func (p *TSSLSocket) Open() error { var err error // If we have a hostname, we need to pass the hostname to tls.Dial for // certificate hostname checks. 
if p.hostPort != "" { if p.conn, err = tls.Dial("tcp", p.hostPort, p.cfg); err != nil { return NewTTransportException(NOT_OPEN, err.Error()) } } else { if p.IsOpen() { return NewTTransportException(ALREADY_OPEN, "Socket already connected.") } if p.addr == nil { return NewTTransportException(NOT_OPEN, "Cannot open nil address.") } if len(p.addr.Network()) == 0 { return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") } if len(p.addr.String()) == 0 { return NewTTransportException(NOT_OPEN, "Cannot open bad address.") } if p.conn, err = tls.Dial(p.addr.Network(), p.addr.String(), p.cfg); err != nil { return NewTTransportException(NOT_OPEN, err.Error()) } } return nil } // Retrieve the underlying net.Conn func (p *TSSLSocket) Conn() net.Conn { return p.conn } // Returns true if the connection is open func (p *TSSLSocket) IsOpen() bool { if p.conn == nil { return false } return true } // Closes the socket. func (p *TSSLSocket) Close() error { // Close the socket if p.conn != nil { err := p.conn.Close() if err != nil { return err } p.conn = nil } return nil } func (p *TSSLSocket) Read(buf []byte) (int, error) { if !p.IsOpen() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(true, false) n, err := p.conn.Read(buf) return n, NewTTransportExceptionFromError(err) } func (p *TSSLSocket) Write(buf []byte) (int, error) { if !p.IsOpen() { return 0, NewTTransportException(NOT_OPEN, "Connection not open") } p.pushDeadline(false, true) return p.conn.Write(buf) } func (p *TSSLSocket) Flush() error { return nil } func (p *TSSLSocket) Interrupt() error { if !p.IsOpen() { return nil } return p.conn.Close() } func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) return maxSize // the thruth is, we just don't know unless framed is used } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/transport.go ================================================ /* 
* Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package thrift

import (
	"errors"
	"io"
)

// errTransportInterrupted is the error TSSLServerSocket.Accept returns once
// the server socket has been interrupted.
var errTransportInterrupted = errors.New("Transport Interrupted")

// Flusher is implemented by anything that can flush buffered output.
type Flusher interface {
	Flush() (err error)
}

// ReadSizeProvider reports how many bytes remain to be read. Socket
// transports in this package return the maximum uint64 because the amount is
// unknown for unframed streams.
type ReadSizeProvider interface {
	RemainingBytes() (num_bytes uint64)
}

// Encapsulates the I/O layer
type TTransport interface {
	io.ReadWriteCloser
	Flusher
	ReadSizeProvider

	// Opens the transport for communication
	Open() error

	// Returns true if the transport is open
	IsOpen() bool
}

// stringWriter is implemented by writers that can write a string directly.
type stringWriter interface {
	WriteString(s string) (n int, err error)
}

// This is an "enhanced" transport with extra capabilities. You need to use one of these
// to construct a protocol.
// Notably, TSocket does not implement this interface, and it is always a mistake to use
// TSocket directly in a protocol.
type TRichTransport interface {
	io.ReadWriter
	io.ByteReader
	io.ByteWriter
	stringWriter
	Flusher
	ReadSizeProvider
}

================================================
FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/transport_exception.go
================================================
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.
See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "errors" "io" ) type timeoutable interface { Timeout() bool } // Thrift Transport exception type TTransportException interface { TException TypeId() int Err() error } const ( UNKNOWN_TRANSPORT_EXCEPTION = 0 NOT_OPEN = 1 ALREADY_OPEN = 2 TIMED_OUT = 3 END_OF_FILE = 4 ) type tTransportException struct { typeId int err error } func (p *tTransportException) TypeId() int { return p.typeId } func (p *tTransportException) Error() string { return p.err.Error() } func (p *tTransportException) Err() error { return p.err } func NewTTransportException(t int, e string) TTransportException { return &tTransportException{typeId: t, err: errors.New(e)} } func NewTTransportExceptionFromError(e error) TTransportException { if e == nil { return nil } if t, ok := e.(TTransportException); ok { return t } switch v := e.(type) { case TTransportException: return v case timeoutable: if v.Timeout() { return &tTransportException{typeId: TIMED_OUT, err: e} } } if e == io.EOF { return &tTransportException{typeId: END_OF_FILE, err: e} } return &tTransportException{typeId: UNKNOWN_TRANSPORT_EXCEPTION, err: e} } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/transport_exception_test.go ================================================ /* * Licensed to 
the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package thrift

import (
	"fmt"
	"io"
	"testing"
)

// timeout is a test error whose Timeout() result is configurable.
type timeout struct{ timedout bool }

// Timeout reports the configured flag.
func (t *timeout) Timeout() bool {
	return t.timedout
}

// Error renders the flag so the message can be compared in tests.
func (t *timeout) Error() string {
	return fmt.Sprintf("Timeout: %v", t.timedout)
}

// TestTExceptionTimeout checks that an error reporting Timeout() == true is
// converted to a TIMED_OUT exception that keeps the original message.
func TestTExceptionTimeout(t *testing.T) {
	timeout := &timeout{true}
	exception := NewTTransportExceptionFromError(timeout)
	if timeout.Error() != exception.Error() {
		t.Fatalf("Error did not match: expected %q, got %q", timeout.Error(), exception.Error())
	}

	if exception.TypeId() != TIMED_OUT {
		t.Fatalf("TypeId was not TIMED_OUT: expected %v, got %v", TIMED_OUT, exception.TypeId())
	}
}

// TestTExceptionEOF checks that io.EOF is converted to an END_OF_FILE
// exception that keeps the original message.
func TestTExceptionEOF(t *testing.T) {
	exception := NewTTransportExceptionFromError(io.EOF)
	if io.EOF.Error() != exception.Error() {
		t.Fatalf("Error did not match: expected %q, got %q", io.EOF.Error(), exception.Error())
	}

	if exception.TypeId() != END_OF_FILE {
		t.Fatalf("TypeId was not END_OF_FILE: expected %v, got %v", END_OF_FILE, exception.TypeId())
	}
}

================================================
FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/transport_factory.go
================================================
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor
license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package thrift

// Factory class used to create wrapped instance of Transports.
// This is used primarily in servers, which get Transports from
// a ServerTransport and then may want to mutate them (i.e. create
// a BufferedTransport from the underlying base transport)
type TTransportFactory interface {
	GetTransport(trans TTransport) TTransport
}

// tTransportFactory is the default factory; it performs no wrapping.
type tTransportFactory struct{}

// Return a wrapped instance of the base Transport.
// This implementation returns trans itself, unchanged.
func (p *tTransportFactory) GetTransport(trans TTransport) TTransport {
	return trans
}

// NewTTransportFactory returns the pass-through factory.
func NewTTransportFactory() TTransportFactory {
	return &tTransportFactory{}
}

================================================
FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/transport_test.go
================================================
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "io" "net" "strconv" "testing" ) const TRANSPORT_BINARY_DATA_SIZE = 4096 var ( transport_bdata []byte // test data for writing; same as data transport_header map[string]string ) func init() { transport_bdata = make([]byte, TRANSPORT_BINARY_DATA_SIZE) for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ { transport_bdata[i] = byte((i + 'a') % 255) } transport_header = map[string]string{"key": "User-Agent", "value": "Mozilla/5.0 (Windows NT 6.2; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/32.0.1667.0 Safari/537.36"} } func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE) if !writeTrans.IsOpen() { t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans) } if !readTrans.IsOpen() { t.Fatalf("Transport %T not open: %s", readTrans, readTrans) } _, err := writeTrans.Write(transport_bdata) if err != nil { t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err) } err = writeTrans.Flush() if err != nil { t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err) } n, err := io.ReadFull(readTrans, buf) if err != nil { t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err) } if n != TRANSPORT_BINARY_DATA_SIZE { t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) } for k, v := range buf { if v != transport_bdata[k] { t.Fatalf("Transport %T read %d instead of %d for index %d of 
binary data 2", readTrans, v, transport_bdata[k], k) } } _, err = writeTrans.Write(transport_bdata) if err != nil { t.Fatalf("Transport %T cannot write binary data 2 of length %d: %s", writeTrans, len(transport_bdata), err) } err = writeTrans.Flush() if err != nil { t.Fatalf("Transport %T cannot flush write binary data 2: %s", writeTrans, err) } buf = make([]byte, TRANSPORT_BINARY_DATA_SIZE) read := 1 for n = 0; n < TRANSPORT_BINARY_DATA_SIZE && read != 0; { read, err = readTrans.Read(buf[n:]) if err != nil { t.Errorf("Transport %T cannot read binary data 2 of total length %d from offset %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, n, err) } n += read } if n != TRANSPORT_BINARY_DATA_SIZE { t.Errorf("Transport %T read only %d instead of %d bytes of binary data 2", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) } for k, v := range buf { if v != transport_bdata[k] { t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) } } } func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE) if !writeTrans.IsOpen() { t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans) } if !readTrans.IsOpen() { t.Fatalf("Transport %T not open: %s", readTrans, readTrans) } // Need to assert type of TTransport to THttpClient to expose the Setter httpWPostTrans := writeTrans.(*THttpClient) httpWPostTrans.SetHeader(transport_header["key"], transport_header["value"]) _, err := writeTrans.Write(transport_bdata) if err != nil { t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err) } err = writeTrans.Flush() if err != nil { t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err) } // Need to assert type of TTransport to THttpClient to expose the Getter httpRPostTrans := readTrans.(*THttpClient) readHeader := httpRPostTrans.GetHeader(transport_header["key"]) if err != nil { 
t.Errorf("Transport %T cannot read HTTP Header Value", httpRPostTrans) } if transport_header["value"] != readHeader { t.Errorf("Expected HTTP Header Value %s, got %s", transport_header["value"], readHeader) } n, err := io.ReadFull(readTrans, buf) if err != nil { t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err) } if n != TRANSPORT_BINARY_DATA_SIZE { t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) } for k, v := range buf { if v != transport_bdata[k] { t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) } } } func CloseTransports(t *testing.T, readTrans TTransport, writeTrans TTransport) { err := readTrans.Close() if err != nil { t.Errorf("Transport %T cannot close read transport: %s", readTrans, err) } if writeTrans != readTrans { err = writeTrans.Close() if err != nil { t.Errorf("Transport %T cannot close write transport: %s", writeTrans, err) } } } func FindAvailableTCPServerPort(startPort int) (net.Addr, error) { for i := startPort; i < 65535; i++ { s := "127.0.0.1:" + strconv.Itoa(i) l, err := net.Listen("tcp", s) if err == nil { l.Close() return net.ResolveTCPAddr("tcp", s) } } return nil, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Could not find available server port") } func valueInSlice(value string, slice []string) bool { for _, v := range slice { if value == v { return true } } return false } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/type.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift // Type constants in the Thrift protocol type TType byte const ( STOP = 0 VOID = 1 BOOL = 2 BYTE = 3 I08 = 3 DOUBLE = 4 I16 = 6 I32 = 8 I64 = 10 STRING = 11 UTF7 = 11 STRUCT = 12 MAP = 13 SET = 14 LIST = 15 UTF8 = 16 UTF16 = 17 //BINARY = 18 wrong and unusued ) var typeNames = map[int]string{ STOP: "STOP", VOID: "VOID", BOOL: "BOOL", BYTE: "BYTE", DOUBLE: "DOUBLE", I16: "I16", I32: "I32", I64: "I64", STRING: "STRING", STRUCT: "STRUCT", MAP: "MAP", SET: "SET", LIST: "LIST", UTF8: "UTF8", UTF16: "UTF16", } func (p TType) String() string { if s, ok := typeNames[int(p)]; ok { return s } return "Unknown" } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/zlib_transport.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package thrift import ( "compress/zlib" "io" "log" ) // TZlibTransportFactory is a factory for TZlibTransport instances type TZlibTransportFactory struct { level int } // TZlibTransport is a TTransport implementation that makes use of zlib compression. type TZlibTransport struct { reader io.ReadCloser transport TTransport writer *zlib.Writer } // GetTransport constructs a new instance of NewTZlibTransport func (p *TZlibTransportFactory) GetTransport(trans TTransport) TTransport { t, _ := NewTZlibTransport(trans, p.level) return t } // NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory func NewTZlibTransportFactory(level int) *TZlibTransportFactory { return &TZlibTransportFactory{level: level} } // NewTZlibTransport constructs a new instance of TZlibTransport func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) { w, err := zlib.NewWriterLevel(trans, level) if err != nil { log.Println(err) return nil, err } return &TZlibTransport{ writer: w, transport: trans, }, nil } // Close closes the reader and writer (flushing any unwritten data) and closes // the underlying transport. func (z *TZlibTransport) Close() error { if z.reader != nil { if err := z.reader.Close(); err != nil { return err } } if err := z.writer.Close(); err != nil { return err } return z.transport.Close() } // Flush flushes the writer and its underlying transport. 
func (z *TZlibTransport) Flush() error { if err := z.writer.Flush(); err != nil { return err } return z.transport.Flush() } // IsOpen returns true if the transport is open func (z *TZlibTransport) IsOpen() bool { return z.transport.IsOpen() } // Open opens the transport for communication func (z *TZlibTransport) Open() error { return z.transport.Open() } func (z *TZlibTransport) Read(p []byte) (int, error) { if z.reader == nil { r, err := zlib.NewReader(z.transport) if err != nil { return 0, NewTTransportExceptionFromError(err) } z.reader = r } return z.reader.Read(p) } // RemainingBytes returns the size in bytes of the data that is still to be // read. func (z *TZlibTransport) RemainingBytes() uint64 { return z.transport.RemainingBytes() } func (z *TZlibTransport) Write(p []byte) (int, error) { return z.writer.Write(p) } ================================================ FILE: thirdparty/github.com/apache/thrift/lib/go/thrift/zlib_transport_test.go ================================================ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. 
*/ package thrift import ( "compress/zlib" "testing" ) func TestZlibTransport(t *testing.T) { trans, err := NewTZlibTransport(NewTMemoryBuffer(), zlib.BestCompression) if err != nil { t.Fatal(err) } TransportTest(t, trans, trans) } ================================================ FILE: thrift/arg2/kv_iterator.go ================================================ // Package arg2 contains tchannel thrift Arg2 interfaces for external use. // // These interfaces are currently unstable, and aren't covered by the API // backwards-compatibility guarantee. package arg2 import ( "encoding/binary" "io" "github.com/uber/tchannel-go/typed" ) // KeyValIterator is a iterator for reading tchannel-thrift Arg2 Scheme, // which has key/value pairs (k~2 v~2). // NOTE: to be optimized for performance, we try to limit the allocation // done in the process of iteration. type KeyValIterator struct { remaining []byte leftPairCount int key []byte val []byte } // NewKeyValIterator inits a KeyValIterator with the buffer pointing at // start of Arg2. Return io.EOF if no iterator is available. // NOTE: tchannel-thrift Arg Scheme starts with number of key/value pair. func NewKeyValIterator(arg2Payload []byte) (KeyValIterator, error) { if len(arg2Payload) < 2 { return KeyValIterator{}, io.EOF } leftPairCount := binary.BigEndian.Uint16(arg2Payload[0:2]) return KeyValIterator{ leftPairCount: int(leftPairCount), remaining: arg2Payload[2:], }.Next() } // Key Returns the key. func (i KeyValIterator) Key() []byte { return i.key } // Value returns value. func (i KeyValIterator) Value() []byte { return i.val } // Remaining returns whether there's any pairs left to consume. func (i KeyValIterator) Remaining() bool { return i.leftPairCount > 0 } // Next returns next iterator. Return io.EOF if no more key/value pair is // available. 
// // Note: We used named returns because of an unexpected performance improvement // See https://github.com/golang/go/issues/40638 func (i KeyValIterator) Next() (kv KeyValIterator, _ error) { if i.leftPairCount <= 0 { return KeyValIterator{}, io.EOF } rbuf := typed.NewReadBuffer(i.remaining) keyLen := int(rbuf.ReadUint16()) key := rbuf.ReadBytes(keyLen) valLen := int(rbuf.ReadUint16()) val := rbuf.ReadBytes(valLen) if rbuf.Err() != nil { return KeyValIterator{}, rbuf.Err() } leftPairCount := i.leftPairCount - 1 kv = KeyValIterator{ remaining: rbuf.Remaining(), leftPairCount: leftPairCount, key: key, val: val, } return kv, nil } ================================================ FILE: thrift/arg2/kv_iterator_test.go ================================================ package arg2 import ( "fmt" "io" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/tchannel-go/testutils/thriftarg2test" ) func TestKeyValIterator(t *testing.T) { const ( testBufSize = 100 nh = 5 ) kv := make(map[string]string, nh) for i := 0; i < nh; i++ { kv[fmt.Sprintf("key%v", i)] = fmt.Sprintf("value%v", i) } buf := thriftarg2test.BuildKVBuffer(kv) iter, err := NewKeyValIterator(buf) gotKV := make(map[string]string) for i := 0; i < nh; i++ { assert.NoError(t, err) gotKV[fmt.Sprintf("key%v", i)] = fmt.Sprintf("value%v", i) remaining := iter.Remaining() iter, err = iter.Next() assert.Equal(t, err == nil, remaining, "Expect remaining to be true if there's no errors") } assert.Equal(t, io.EOF, err) assert.Equal(t, kv, gotKV) t.Run("init iterator w/o Arg2", func(t *testing.T) { _, err := NewKeyValIterator(nil) assert.Equal(t, io.EOF, err) }) t.Run("init iterator w/o pairs", func(t *testing.T) { buf := thriftarg2test.BuildKVBuffer(nil /*kv*/) _, err := NewKeyValIterator(buf) assert.Equal(t, io.EOF, err) }) t.Run("bad key value length", func(t *testing.T) { buf := thriftarg2test.BuildKVBuffer(map[string]string{ "key": "value", }) tests := []struct { msg 
string arg2Len int wantErr string }{ { msg: "ok", arg2Len: len(buf), }, { msg: "not enough to read key len", arg2Len: 3, // nh (2) + 1 wantErr: "buffer is too small", }, { msg: "not enough to hold key value", arg2Len: 6, // nh (2) + 2 + len(key) - 1 wantErr: "buffer is too small", }, { msg: "not enough to read value len", arg2Len: 8, // nh (2) + 2 + len(key) + 1 wantErr: "buffer is too small", }, { msg: "not enough to iterate value", arg2Len: 13, // nh (2) + 2 + len(key) + 2 + len(value) = 14 wantErr: "buffer is too small", }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { iter, err := NewKeyValIterator(buf[:tt.arg2Len]) if tt.wantErr == "" { assert.NoError(t, err) assert.Equal(t, "key", string(iter.Key()), "unexpected key") assert.Equal(t, "value", string(iter.Value()), "unexpected value") return } require.Error(t, err, "should not create iterator") assert.Contains(t, err.Error(), tt.wantErr) }) } }) } func BenchmarkKeyValIterator(b *testing.B) { kvBuffer := thriftarg2test.BuildKVBuffer(map[string]string{ "foo": "bar", "baz": "qux", "quux": "corge", }) for i := 0; i < b.N; i++ { iter, err := NewKeyValIterator(kvBuffer) if err != nil { b.Fatalf("unexpected err %v", err) } for iter.Remaining() { iter, err = iter.Next() if err != nil { b.Fatalf("unexpected err %v", err) } } } } ================================================ FILE: thrift/client.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package thrift

import (
	"github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/internal/argreader"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"

	"golang.org/x/net/context"
)

// client implements TChanClient and makes outgoing Thrift calls.
type client struct {
	// ch is the channel used to begin calls directly against opts.HostPort
	// and to run each call with retries.
	ch *tchannel.Channel
	// sc is the subchannel for serviceName; calls are begun on it when no
	// HostPort is configured.
	sc *tchannel.SubChannel
	// serviceName is the target service passed along when beginning a call
	// on ch.
	serviceName string
	// opts is a copy of the options given to NewClient (the zero value when
	// none were given).
	opts ClientOptions
}

// ClientOptions are options to customize the client.
type ClientOptions struct {
	// HostPort specifies a specific server to hit.
	// When empty, calls are begun on the service's subchannel instead.
	HostPort string
}

// NewClient returns a Client that makes calls over the given tchannel to the given Hyperbahn service.
func NewClient(ch *tchannel.Channel, serviceName string, opts *ClientOptions) TChanClient { client := &client{ ch: ch, sc: ch.GetSubChannel(serviceName), serviceName: serviceName, } if opts != nil { client.opts = *opts } return client } func (c *client) startCall(ctx context.Context, method string, callOptions *tchannel.CallOptions) (*tchannel.OutboundCall, error) { if c.opts.HostPort != "" { return c.ch.BeginCall(ctx, c.opts.HostPort, c.serviceName, method, callOptions) } return c.sc.BeginCall(ctx, method, callOptions) } func writeArgs(call *tchannel.OutboundCall, headers map[string]string, req thrift.TStruct) error { writer, err := call.Arg2Writer() if err != nil { return err } headers = tchannel.InjectOutboundSpan(call.Response(), headers) if err := WriteHeaders(writer, headers); err != nil { return err } if err := writer.Close(); err != nil { return err } writer, err = call.Arg3Writer() if err != nil { return err } if err := WriteStruct(writer, req); err != nil { return err } return writer.Close() } // readResponse reads the response struct into resp, and returns: // (response headers, whether there was an application error, unexpected error). 
func readResponse(response *tchannel.OutboundCallResponse, resp thrift.TStruct) (map[string]string, bool, error) { reader, err := response.Arg2Reader() if err != nil { return nil, false, err } headers, err := ReadHeaders(reader) if err != nil { return nil, false, err } if err := argreader.EnsureEmpty(reader, "reading response headers"); err != nil { return nil, false, err } if err := reader.Close(); err != nil { return nil, false, err } success := !response.ApplicationError() reader, err = response.Arg3Reader() if err != nil { return headers, success, err } if err := ReadStruct(reader, resp); err != nil { return headers, success, err } if err := argreader.EnsureEmpty(reader, "reading response body"); err != nil { return nil, false, err } return headers, success, reader.Close() } func (c *client) Call(ctx Context, thriftService, methodName string, req, resp thrift.TStruct) (bool, error) { var ( headers = ctx.Headers() respHeaders map[string]string isOK bool ) err := c.ch.RunWithRetry(ctx, func(ctx context.Context, rs *tchannel.RequestState) error { respHeaders, isOK = nil, false call, err := c.startCall(ctx, thriftService+"::"+methodName, &tchannel.CallOptions{ Format: tchannel.Thrift, RequestState: rs, }) if err != nil { return err } if err := writeArgs(call, headers, req); err != nil { return err } respHeaders, isOK, err = readResponse(call.Response(), resp) return err }) if err != nil { return false, err } ctx.SetResponseHeaders(respHeaders) return isOK, nil } ================================================ FILE: thrift/context.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift import ( "time" "github.com/uber/tchannel-go" "golang.org/x/net/context" ) // Context is a Thrift Context which contains request and response headers. type Context tchannel.ContextWithHeaders // NewContext returns a Context that can be used to make Thrift calls. func NewContext(timeout time.Duration) (Context, context.CancelFunc) { ctx, cancel := tchannel.NewContext(timeout) return Wrap(ctx), cancel } // Wrap returns a Thrift Context that wraps around a Context. func Wrap(ctx context.Context) Context { return tchannel.Wrap(ctx) } // WithHeaders returns a Context that can be used to make a call with request headers. 
func WithHeaders(ctx context.Context, headers map[string]string) Context { return tchannel.WrapWithHeaders(ctx, headers) } ================================================ FILE: thrift/context_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift_test import ( "errors" "testing" "time" . 
"github.com/uber/tchannel-go/thrift" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" gen "github.com/uber/tchannel-go/thrift/gen-go/test" "github.com/stretchr/testify/assert" "golang.org/x/net/context" ) func TestWrapContext(t *testing.T) { tctx, cancel := NewContext(time.Second) defer cancel() headers := map[string]string{"h1": "v1"} ctx := context.WithValue(WithHeaders(tctx, headers), "1", "2") wrapped := Wrap(ctx) assert.NotNil(t, wrapped, "Should not return nil.") assert.Equal(t, headers, wrapped.Headers(), "Unexpected headers") assert.Equal(t, "2", wrapped.Value("1"), "Unexpected value") } func TestContextBuilder(t *testing.T) { ctx, cancel := tchannel.NewContextBuilder(time.Second).SetShardKey("shard").Build() defer cancel() var called bool testutils.WithServer(t, nil, func(ch *tchannel.Channel, hostPort string) { peerInfo := ch.PeerInfo() testutils.RegisterFunc(ch, "SecondService::Echo", func(ctx context.Context, args *raw.Args) (*raw.Res, error) { call := tchannel.CurrentCall(ctx) assert.Equal(t, peerInfo.ServiceName, call.CallerName(), "unexpected caller name") assert.Equal(t, "shard", call.ShardKey(), "unexpected shard key") assert.Equal(t, tchannel.Thrift, args.Format) called = true return nil, errors.New("err") }) client := NewClient(ch, ch.PeerInfo().ServiceName, &ClientOptions{ HostPort: peerInfo.HostPort, }) secondClient := gen.NewTChanSecondServiceClient(client) secondClient.Echo(ctx, "asd") assert.True(t, called, "test not called") }) } ================================================ FILE: thrift/doc.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

/*
Package thrift adds support to use Thrift services over TChannel.

To start listening to a Thrift service using TChannel, create the channel,
and register the service using:

	server := thrift.NewServer(tchan)
	server.Register(gen.NewTChan[SERVICE]Server(handler))

	// Any number of services can be registered on the same Thrift server.
	server.Register(gen.NewTChan[SERVICE2]Server(handler))

To use a Thrift client use the generated TChan client:

	thriftClient := thrift.NewClient(ch, "hyperbahnService", nil)
	client := gen.NewTChan[SERVICE]Client(thriftClient)

	// Any number of service clients can be made using the same Thrift client.
	client2 := gen.NewTChan[SERVICE2]Client(thriftClient)

This client can be used similarly to a standard Thrift client, except a Context
is passed with options (such as timeout).

TODO(prashant): Add and document header support.
*/ package thrift ================================================ FILE: thrift/errors_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift_test import ( "testing" "time" // Test is in a separate package to avoid circular dependencies. . 
"github.com/uber/tchannel-go/thrift" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/raw" "github.com/uber/tchannel-go/testutils" gen "github.com/uber/tchannel-go/thrift/gen-go/test" "github.com/uber/tchannel-go/thrift/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) func serializeStruct(t *testing.T, s thrift.TStruct) []byte { trans := thrift.NewTMemoryBuffer() p := thrift.NewTBinaryProtocolTransport(trans) require.NoError(t, s.Write(p), "Struct serialization failed") return trans.Bytes() } func TestInvalidThriftBytes(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() ch := testutils.NewClient(t, nil) sCh := testutils.NewServer(t, nil) defer sCh.Close() svr := NewServer(sCh) svr.Register(gen.NewTChanSecondServiceServer(new(mocks.TChanSecondService))) tests := []struct { name string arg3 []byte }{ { name: "missing bytes", arg3: serializeStruct(t, &gen.SecondServiceEchoArgs{Arg: "Hello world"})[:5], }, { name: "wrong struct", arg3: serializeStruct(t, &gen.Data{B1: true}), }, } for _, tt := range tests { sPeer := sCh.PeerInfo() call, err := ch.BeginCall(ctx, sPeer.HostPort, sPeer.ServiceName, "SecondService::Echo", &tchannel.CallOptions{ Format: tchannel.Thrift, }) require.NoError(t, err, "BeginCall failed") require.NoError(t, tchannel.NewArgWriter(call.Arg2Writer()).Write([]byte{0, 0}), "Write arg2 failed") writer, err := call.Arg3Writer() require.NoError(t, err, "Arg3Writer failed") _, err = writer.Write(tt.arg3) require.NoError(t, err, "Write arg3 failed") require.NoError(t, writer.Close(), "Close failed") response := call.Response() _, _, err = raw.ReadArgsV2(response) assert.Error(t, err, "%v: Expected error", tt.name) assert.Equal(t, tchannel.ErrCodeBadRequest, tchannel.GetSystemErrorCode(err), "%v: Expected bad request, got %v", tt.name, err) } } ================================================ FILE: 
thrift/gen-go/meta/constants.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package meta import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal func init() { } ================================================ FILE: thrift/gen-go/meta/meta.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package meta import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal type Meta interface { // Parameters: // - Hr Health(hr *HealthRequest) (r *HealthStatus, err error) ThriftIDL() (r *ThriftIDLs, err error) VersionInfo() (r *VersionInfo, err error) } type MetaClient struct { Transport thrift.TTransport ProtocolFactory thrift.TProtocolFactory InputProtocol thrift.TProtocol OutputProtocol thrift.TProtocol SeqId int32 } func NewMetaClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *MetaClient { return &MetaClient{Transport: t, ProtocolFactory: f, InputProtocol: f.GetProtocol(t), OutputProtocol: f.GetProtocol(t), SeqId: 0, } } func NewMetaClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *MetaClient { return &MetaClient{Transport: t, ProtocolFactory: nil, InputProtocol: iprot, OutputProtocol: oprot, SeqId: 0, } } // Parameters: // - Hr func (p *MetaClient) Health(hr *HealthRequest) (r *HealthStatus, err error) { if err = p.sendHealth(hr); err != nil { return } return p.recvHealth() } func (p *MetaClient) sendHealth(hr *HealthRequest) (err error) { 
oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("health", thrift.CALL, p.SeqId); err != nil { return } args := MetaHealthArgs{ Hr: hr, } if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *MetaClient) recvHealth() (value *HealthStatus, err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "health" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "health failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "health failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error2 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error3 error error3, err = error2.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error3 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "health failed: invalid message type") return } result := MetaHealthResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } value = result.GetSuccess() return } func (p *MetaClient) ThriftIDL() (r *ThriftIDLs, err error) { if err = p.sendThriftIDL(); err != nil { return } return p.recvThriftIDL() } func (p *MetaClient) sendThriftIDL() (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("thriftIDL", thrift.CALL, p.SeqId); err != nil { return } args := MetaThriftIDLArgs{} if err = 
args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *MetaClient) recvThriftIDL() (value *ThriftIDLs, err error) { iprot := p.InputProtocol if iprot == nil { iprot = p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "thriftIDL" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "thriftIDL failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "thriftIDL failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error4 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error5 error error5, err = error4.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error5 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "thriftIDL failed: invalid message type") return } result := MetaThriftIDLResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } value = result.GetSuccess() return } func (p *MetaClient) VersionInfo() (r *VersionInfo, err error) { if err = p.sendVersionInfo(); err != nil { return } return p.recvVersionInfo() } func (p *MetaClient) sendVersionInfo() (err error) { oprot := p.OutputProtocol if oprot == nil { oprot = p.ProtocolFactory.GetProtocol(p.Transport) p.OutputProtocol = oprot } p.SeqId++ if err = oprot.WriteMessageBegin("versionInfo", thrift.CALL, p.SeqId); err != nil { return } args := MetaVersionInfoArgs{} if err = args.Write(oprot); err != nil { return } if err = oprot.WriteMessageEnd(); err != nil { return } return oprot.Flush() } func (p *MetaClient) recvVersionInfo() (value *VersionInfo, err error) { iprot := p.InputProtocol if iprot == nil { iprot = 
p.ProtocolFactory.GetProtocol(p.Transport) p.InputProtocol = iprot } method, mTypeId, seqId, err := iprot.ReadMessageBegin() if err != nil { return } if method != "versionInfo" { err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "versionInfo failed: wrong method name") return } if p.SeqId != seqId { err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "versionInfo failed: out of sequence response") return } if mTypeId == thrift.EXCEPTION { error6 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception") var error7 error error7, err = error6.Read(iprot) if err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } err = error7 return } if mTypeId != thrift.REPLY { err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "versionInfo failed: invalid message type") return } result := MetaVersionInfoResult{} if err = result.Read(iprot); err != nil { return } if err = iprot.ReadMessageEnd(); err != nil { return } value = result.GetSuccess() return } type MetaProcessor struct { processorMap map[string]thrift.TProcessorFunction handler Meta } func (p *MetaProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) { p.processorMap[key] = processor } func (p *MetaProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) { processor, ok = p.processorMap[key] return processor, ok } func (p *MetaProcessor) ProcessorMap() map[string]thrift.TProcessorFunction { return p.processorMap } func NewMetaProcessor(handler Meta) *MetaProcessor { self8 := &MetaProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)} self8.processorMap["health"] = &metaProcessorHealth{handler: handler} self8.processorMap["thriftIDL"] = &metaProcessorThriftIDL{handler: handler} self8.processorMap["versionInfo"] = &metaProcessorVersionInfo{handler: handler} return self8 } func (p *MetaProcessor) Process(iprot, oprot thrift.TProtocol) 
(success bool, err thrift.TException) { name, _, seqId, err := iprot.ReadMessageBegin() if err != nil { return false, err } if processor, ok := p.GetProcessorFunction(name); ok { return processor.Process(seqId, iprot, oprot) } iprot.Skip(thrift.STRUCT) iprot.ReadMessageEnd() x9 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name) oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId) x9.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, x9 } type metaProcessorHealth struct { handler Meta } func (p *metaProcessorHealth) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := MetaHealthArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("health", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := MetaHealthResult{} var retval *HealthStatus var err2 error if retval, err2 = p.handler.Health(args.Hr); err2 != nil { x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing health: "+err2.Error()) oprot.WriteMessageBegin("health", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } else { result.Success = retval } if err2 = oprot.WriteMessageBegin("health", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } type metaProcessorThriftIDL struct { handler Meta } func (p *metaProcessorThriftIDL) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := MetaThriftIDLArgs{} if err = args.Read(iprot); err != nil { 
iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("thriftIDL", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := MetaThriftIDLResult{} var retval *ThriftIDLs var err2 error if retval, err2 = p.handler.ThriftIDL(); err2 != nil { x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing thriftIDL: "+err2.Error()) oprot.WriteMessageBegin("thriftIDL", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } else { result.Success = retval } if err2 = oprot.WriteMessageBegin("thriftIDL", thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } type metaProcessorVersionInfo struct { handler Meta } func (p *metaProcessorVersionInfo) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) { args := MetaVersionInfoArgs{} if err = args.Read(iprot); err != nil { iprot.ReadMessageEnd() x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error()) oprot.WriteMessageBegin("versionInfo", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return false, err } iprot.ReadMessageEnd() result := MetaVersionInfoResult{} var retval *VersionInfo var err2 error if retval, err2 = p.handler.VersionInfo(); err2 != nil { x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing versionInfo: "+err2.Error()) oprot.WriteMessageBegin("versionInfo", thrift.EXCEPTION, seqId) x.Write(oprot) oprot.WriteMessageEnd() oprot.Flush() return true, err2 } else { result.Success = retval } if err2 = oprot.WriteMessageBegin("versionInfo", 
thrift.REPLY, seqId); err2 != nil { err = err2 } if err2 = result.Write(oprot); err == nil && err2 != nil { err = err2 } if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil { err = err2 } if err2 = oprot.Flush(); err == nil && err2 != nil { err = err2 } if err != nil { return } return true, err } // HELPER FUNCTIONS AND STRUCTURES // Attributes: // - Hr type MetaHealthArgs struct { Hr *HealthRequest `thrift:"hr,1" db:"hr" json:"hr"` } func NewMetaHealthArgs() *MetaHealthArgs { return &MetaHealthArgs{} } var MetaHealthArgs_Hr_DEFAULT *HealthRequest func (p *MetaHealthArgs) GetHr() *HealthRequest { if !p.IsSetHr() { return MetaHealthArgs_Hr_DEFAULT } return p.Hr } func (p *MetaHealthArgs) IsSetHr() bool { return p.Hr != nil } func (p *MetaHealthArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *MetaHealthArgs) ReadField1(iprot thrift.TProtocol) error { p.Hr = &HealthRequest{} if err := p.Hr.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Hr), err) } return nil } func (p *MetaHealthArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("health_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := 
oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *MetaHealthArgs) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("hr", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:hr: ", p), err) } if err := p.Hr.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Hr), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:hr: ", p), err) } return err } func (p *MetaHealthArgs) String() string { if p == nil { return "" } return fmt.Sprintf("MetaHealthArgs(%+v)", *p) } // Attributes: // - Success type MetaHealthResult struct { Success *HealthStatus `thrift:"success,0" db:"success" json:"success,omitempty"` } func NewMetaHealthResult() *MetaHealthResult { return &MetaHealthResult{} } var MetaHealthResult_Success_DEFAULT *HealthStatus func (p *MetaHealthResult) GetSuccess() *HealthStatus { if !p.IsSetSuccess() { return MetaHealthResult_Success_DEFAULT } return p.Success } func (p *MetaHealthResult) IsSetSuccess() bool { return p.Success != nil } func (p *MetaHealthResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 0: if err := p.ReadField0(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return 
thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *MetaHealthResult) ReadField0(iprot thrift.TProtocol) error { p.Success = &HealthStatus{} if err := p.Success.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Success), err) } return nil } func (p *MetaHealthResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("health_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField0(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *MetaHealthResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err := oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err) } if err := p.Success.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Success), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err) } } return err } func (p *MetaHealthResult) String() string { if p == nil { return "" } return fmt.Sprintf("MetaHealthResult(%+v)", *p) } type MetaThriftIDLArgs struct { } func NewMetaThriftIDLArgs() *MetaThriftIDLArgs { return &MetaThriftIDLArgs{} } func (p *MetaThriftIDLArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { 
break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *MetaThriftIDLArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("thriftIDL_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *MetaThriftIDLArgs) String() string { if p == nil { return "" } return fmt.Sprintf("MetaThriftIDLArgs(%+v)", *p) } // Attributes: // - Success type MetaThriftIDLResult struct { Success *ThriftIDLs `thrift:"success,0" db:"success" json:"success,omitempty"` } func NewMetaThriftIDLResult() *MetaThriftIDLResult { return &MetaThriftIDLResult{} } var MetaThriftIDLResult_Success_DEFAULT *ThriftIDLs func (p *MetaThriftIDLResult) GetSuccess() *ThriftIDLs { if !p.IsSetSuccess() { return MetaThriftIDLResult_Success_DEFAULT } return p.Success } func (p *MetaThriftIDLResult) IsSetSuccess() bool { return p.Success != nil } func (p *MetaThriftIDLResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 0: if err := p.ReadField0(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return 
thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *MetaThriftIDLResult) ReadField0(iprot thrift.TProtocol) error { p.Success = &ThriftIDLs{} if err := p.Success.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Success), err) } return nil } func (p *MetaThriftIDLResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("thriftIDL_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField0(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *MetaThriftIDLResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err := oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err) } if err := p.Success.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Success), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err) } } return err } func (p *MetaThriftIDLResult) String() string { if p == nil { return "" } return fmt.Sprintf("MetaThriftIDLResult(%+v)", *p) } type MetaVersionInfoArgs struct { } func NewMetaVersionInfoArgs() *MetaVersionInfoArgs { return &MetaVersionInfoArgs{} } func (p *MetaVersionInfoArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if 
fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *MetaVersionInfoArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("versionInfo_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *MetaVersionInfoArgs) String() string { if p == nil { return "" } return fmt.Sprintf("MetaVersionInfoArgs(%+v)", *p) } // Attributes: // - Success type MetaVersionInfoResult struct { Success *VersionInfo `thrift:"success,0" db:"success" json:"success,omitempty"` } func NewMetaVersionInfoResult() *MetaVersionInfoResult { return &MetaVersionInfoResult{} } var MetaVersionInfoResult_Success_DEFAULT *VersionInfo func (p *MetaVersionInfoResult) GetSuccess() *VersionInfo { if !p.IsSetSuccess() { return MetaVersionInfoResult_Success_DEFAULT } return p.Success } func (p *MetaVersionInfoResult) IsSetSuccess() bool { return p.Success != nil } func (p *MetaVersionInfoResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 0: if err := p.ReadField0(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } 
} if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *MetaVersionInfoResult) ReadField0(iprot thrift.TProtocol) error { p.Success = &VersionInfo{} if err := p.Success.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Success), err) } return nil } func (p *MetaVersionInfoResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("versionInfo_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField0(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *MetaVersionInfoResult) writeField0(oprot thrift.TProtocol) (err error) { if p.IsSetSuccess() { if err := oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err) } if err := p.Success.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Success), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err) } } return err } func (p *MetaVersionInfoResult) String() string { if p == nil { return "" } return fmt.Sprintf("MetaVersionInfoResult(%+v)", *p) } ================================================ FILE: thrift/gen-go/meta/ttypes.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package meta import ( "bytes" "database/sql/driver" "errors" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to 
ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

var GoUnusedProtection__ int

// HealthState is an int64-backed enumeration with four named values.
type HealthState int64

// Named values of HealthState.
const (
	HealthState_REFUSING  HealthState = 0
	HealthState_ACCEPTING HealthState = 1
	HealthState_STOPPING  HealthState = 2
	HealthState_STOPPED   HealthState = 3
)

// String returns the symbolic name of p, or the empty string when p is not
// one of the named values.
func (p HealthState) String() string {
	names := [...]string{
		HealthState_REFUSING:  "REFUSING",
		HealthState_ACCEPTING: "ACCEPTING",
		HealthState_STOPPING:  "STOPPING",
		HealthState_STOPPED:   "STOPPED",
	}
	if p < 0 || p >= HealthState(len(names)) {
		return ""
	}
	return names[p]
}

// HealthStateFromString maps a symbolic name back to its HealthState.
// An unrecognised name yields HealthState(0) together with an error.
func HealthStateFromString(s string) (HealthState, error) {
	for _, candidate := range []HealthState{
		HealthState_REFUSING,
		HealthState_ACCEPTING,
		HealthState_STOPPING,
		HealthState_STOPPED,
	} {
		if candidate.String() == s {
			return candidate, nil
		}
	}
	return HealthState(0), fmt.Errorf("not a valid HealthState string")
}

// HealthStatePtr returns a pointer to a copy of v.
func HealthStatePtr(v HealthState) *HealthState {
	return &v
}

// MarshalText returns the symbolic name of p as bytes; the error is always nil.
func (p HealthState) MarshalText() ([]byte, error) {
	name := p.String()
	return []byte(name), nil
}

// UnmarshalText parses text with HealthStateFromString and stores the result
// in *p. On a parse error *p is left untouched and the error is returned.
func (p *HealthState) UnmarshalText(text []byte) error {
	parsed, err := HealthStateFromString(string(text))
	if err == nil {
		*p = parsed
	}
	return err
}

// Scan stores value in *p when its dynamic type is int64; any other type is
// rejected with an error.
func (p *HealthState) Scan(value interface{}) error {
	if raw, isInt := value.(int64); isInt {
		*p = HealthState(raw)
		return nil
	}
	return errors.New("Scan value is not int64")
}

// Value returns *p as an int64, or (nil, nil) when the receiver is nil.
func (p *HealthState) Value() (driver.Value, error) {
	if p != nil {
		return int64(*p), nil
	}
	return nil, nil
}

// HealthRequestType is an int64-backed enumeration with two named values.
type HealthRequestType int64

// Named values of HealthRequestType.
const (
	HealthRequestType_PROCESS HealthRequestType = 0
	HealthRequestType_TRAFFIC HealthRequestType = 1
)

// String returns the symbolic name of p, or the empty string when p is not
// one of the named values.
func (p HealthRequestType) String() string {
	names := [...]string{
		HealthRequestType_PROCESS: "PROCESS",
		HealthRequestType_TRAFFIC: "TRAFFIC",
	}
	if p < 0 || p >= HealthRequestType(len(names)) {
		return ""
	}
	return names[p]
}

// HealthRequestTypeFromString maps a symbolic name back to its
// HealthRequestType. An unrecognised name yields HealthRequestType(0)
// together with an error.
func HealthRequestTypeFromString(s string) (HealthRequestType, error) {
	for _, candidate := range []HealthRequestType{
		HealthRequestType_PROCESS,
		HealthRequestType_TRAFFIC,
	} {
		if candidate.String() == s {
			return candidate, nil
		}
	}
	return HealthRequestType(0), fmt.Errorf("not a valid HealthRequestType string")
}

// HealthRequestTypePtr returns a pointer to a copy of v.
func HealthRequestTypePtr(v HealthRequestType) *HealthRequestType {
	return &v
}

// MarshalText returns the symbolic name of p as bytes; the error is always nil.
func (p HealthRequestType) MarshalText() ([]byte, error) {
	name := p.String()
	return []byte(name), nil
}

// UnmarshalText parses text with HealthRequestTypeFromString and stores the
// result in *p. On a parse error *p is left untouched and the error is
// returned.
func (p *HealthRequestType) UnmarshalText(text []byte) error {
	parsed, err := HealthRequestTypeFromString(string(text))
	if err == nil {
		*p = parsed
	}
	return err
}

// Scan stores value in *p when its dynamic type is int64; any other type is
// rejected with an error.
func (p *HealthRequestType) Scan(value interface{}) error {
	if raw, isInt := value.(int64); isInt {
		*p = HealthRequestType(raw)
		return nil
	}
	return errors.New("Scan value is not int64")
}

// Value returns *p as an int64, or (nil, nil) when the receiver is nil.
func (p *HealthRequestType) Value() (driver.Value, error) {
	if p != nil {
		return int64(*p), nil
	}
	return nil, nil
}

// Filename is a named string type.
type Filename string

// FilenamePtr returns a pointer to a copy of v.
func FilenamePtr(v Filename) *Filename {
	return &v
}

// HealthRequest carries a single optional field.
//
// Attributes:
//  - Type: optional request type; nil means "not set".
type HealthRequest struct {
	Type *HealthRequestType `thrift:"type,1" db:"type" json:"type,omitempty"`
}

// NewHealthRequest returns an empty HealthRequest with Type unset.
func NewHealthRequest() *HealthRequest {
	return new(HealthRequest)
}

// HealthRequest_Type_DEFAULT is what GetType reports when Type is unset.
var HealthRequest_Type_DEFAULT HealthRequestType

// GetType returns the value Type points to, or HealthRequest_Type_DEFAULT
// when Type is nil.
func (p *HealthRequest) GetType() HealthRequestType {
	if p.IsSetType() {
		return *p.Type
	}
	return HealthRequest_Type_DEFAULT
}

// IsSetType reports whether the optional Type field is populated.
func (p *HealthRequest) IsSetType() bool {
	return p.Type != nil
}

// Read decodes a HealthRequest from iprot. It consumes fields until a STOP
// marker is seen; field id 1 is handed to ReadField1 and every other field
// is skipped.
func (p *HealthRequest) Read(iprot thrift.TProtocol) error {
	if _, beginErr := iprot.ReadStructBegin(); beginErr != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), beginErr)
	}
	for {
		_, wireType, id, fieldErr := iprot.ReadFieldBegin()
		if fieldErr != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, id), fieldErr)
		}
		if wireType == thrift.STOP {
			break
		}
		if id == 1 {
			if err := p.ReadField1(iprot); err != nil {
				return err
			}
		} else if err := iprot.Skip(wireType); err != nil {
			return err
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if endErr := iprot.ReadStructEnd(); endErr != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), endErr)
	}
	return nil
}

func (p *HealthRequest) ReadField1(iprot
thrift.TProtocol) error { if v, err := iprot.ReadI32(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { temp := HealthRequestType(v) p.Type = &temp } return nil } func (p *HealthRequest) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("HealthRequest"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *HealthRequest) writeField1(oprot thrift.TProtocol) (err error) { if p.IsSetType() { if err := oprot.WriteFieldBegin("type", thrift.I32, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:type: ", p), err) } if err := oprot.WriteI32(int32(*p.Type)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.type (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:type: ", p), err) } } return err } func (p *HealthRequest) String() string { if p == nil { return "" } return fmt.Sprintf("HealthRequest(%+v)", *p) } // Attributes: // - Ok // - Message // - State type HealthStatus struct { Ok bool `thrift:"ok,1,required" db:"ok" json:"ok"` Message *string `thrift:"message,2" db:"message" json:"message,omitempty"` State *HealthState `thrift:"state,3" db:"state" json:"state,omitempty"` } func NewHealthStatus() *HealthStatus { return &HealthStatus{} } func (p *HealthStatus) GetOk() bool { return p.Ok } var HealthStatus_Message_DEFAULT string func (p *HealthStatus) GetMessage() string { if !p.IsSetMessage() { return HealthStatus_Message_DEFAULT } return *p.Message } var HealthStatus_State_DEFAULT HealthState func (p *HealthStatus) GetState() HealthState { if 
!p.IsSetState() { return HealthStatus_State_DEFAULT } return *p.State } func (p *HealthStatus) IsSetMessage() bool { return p.Message != nil } func (p *HealthStatus) IsSetState() bool { return p.State != nil } func (p *HealthStatus) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetOk bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetOk = true case 2: if err := p.ReadField2(iprot); err != nil { return err } case 3: if err := p.ReadField3(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetOk { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Ok is not set")) } return nil } func (p *HealthStatus) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadBool(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Ok = v } return nil } func (p *HealthStatus) ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.Message = &v } return nil } func (p *HealthStatus) ReadField3(iprot thrift.TProtocol) error { if v, err := iprot.ReadI32(); err != nil { return thrift.PrependError("error reading field 3: ", err) } else { temp := HealthState(v) p.State = &temp } return nil } func (p *HealthStatus) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("HealthStatus"); err 
!= nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := p.writeField3(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *HealthStatus) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("ok", thrift.BOOL, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:ok: ", p), err) } if err := oprot.WriteBool(bool(p.Ok)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.ok (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:ok: ", p), err) } return err } func (p *HealthStatus) writeField2(oprot thrift.TProtocol) (err error) { if p.IsSetMessage() { if err := oprot.WriteFieldBegin("message", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:message: ", p), err) } if err := oprot.WriteString(string(*p.Message)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.message (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:message: ", p), err) } } return err } func (p *HealthStatus) writeField3(oprot thrift.TProtocol) (err error) { if p.IsSetState() { if err := oprot.WriteFieldBegin("state", thrift.I32, 3); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 3:state: ", p), err) } if err := oprot.WriteI32(int32(*p.State)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.state (3) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err 
!= nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 3:state: ", p), err) } } return err } func (p *HealthStatus) String() string { if p == nil { return "" } return fmt.Sprintf("HealthStatus(%+v)", *p) } // Attributes: // - Idls // - EntryPoint type ThriftIDLs struct { Idls map[Filename]string `thrift:"idls,1,required" db:"idls" json:"idls"` EntryPoint Filename `thrift:"entryPoint,2,required" db:"entryPoint" json:"entryPoint"` } func NewThriftIDLs() *ThriftIDLs { return &ThriftIDLs{} } func (p *ThriftIDLs) GetIdls() map[Filename]string { return p.Idls } func (p *ThriftIDLs) GetEntryPoint() Filename { return p.EntryPoint } func (p *ThriftIDLs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetIdls bool = false var issetEntryPoint bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetIdls = true case 2: if err := p.ReadField2(iprot); err != nil { return err } issetEntryPoint = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetIdls { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Idls is not set")) } if !issetEntryPoint { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field EntryPoint is not set")) } return nil } func (p *ThriftIDLs) ReadField1(iprot thrift.TProtocol) error { _, _, size, err := iprot.ReadMapBegin() if err != nil { return thrift.PrependError("error reading map begin: ", 
err) } tMap := make(map[Filename]string, size) p.Idls = tMap for i := 0; i < size; i++ { var _key0 Filename if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { temp := Filename(v) _key0 = temp } var _val1 string if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 0: ", err) } else { _val1 = v } p.Idls[_key0] = _val1 } if err := iprot.ReadMapEnd(); err != nil { return thrift.PrependError("error reading map end: ", err) } return nil } func (p *ThriftIDLs) ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { temp := Filename(v) p.EntryPoint = temp } return nil } func (p *ThriftIDLs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("ThriftIDLs"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *ThriftIDLs) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("idls", thrift.MAP, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:idls: ", p), err) } if err := oprot.WriteMapBegin(thrift.STRING, thrift.STRING, len(p.Idls)); err != nil { return thrift.PrependError("error writing map begin: ", err) } for k, v := range p.Idls { if err := oprot.WriteString(string(k)); err != nil { return thrift.PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } if err := oprot.WriteString(string(v)); err != nil { return thrift.PrependError(fmt.Sprintf("%T. 
(0) field write error: ", p), err) } } if err := oprot.WriteMapEnd(); err != nil { return thrift.PrependError("error writing map end: ", err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:idls: ", p), err) } return err } func (p *ThriftIDLs) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("entryPoint", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:entryPoint: ", p), err) } if err := oprot.WriteString(string(p.EntryPoint)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.entryPoint (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:entryPoint: ", p), err) } return err } func (p *ThriftIDLs) String() string { if p == nil { return "" } return fmt.Sprintf("ThriftIDLs(%+v)", *p) } // Attributes: // - Language // - LanguageVersion // - Version type VersionInfo struct { Language string `thrift:"language,1,required" db:"language" json:"language"` LanguageVersion string `thrift:"language_version,2,required" db:"language_version" json:"language_version"` Version string `thrift:"version,3,required" db:"version" json:"version"` } func NewVersionInfo() *VersionInfo { return &VersionInfo{} } func (p *VersionInfo) GetLanguage() string { return p.Language } func (p *VersionInfo) GetLanguageVersion() string { return p.LanguageVersion } func (p *VersionInfo) GetVersion() string { return p.Version } func (p *VersionInfo) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetLanguage bool = false var issetLanguageVersion bool = false var issetVersion bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, 
fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetLanguage = true case 2: if err := p.ReadField2(iprot); err != nil { return err } issetLanguageVersion = true case 3: if err := p.ReadField3(iprot); err != nil { return err } issetVersion = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetLanguage { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Language is not set")) } if !issetLanguageVersion { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field LanguageVersion is not set")) } if !issetVersion { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Version is not set")) } return nil } func (p *VersionInfo) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Language = v } return nil } func (p *VersionInfo) ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.LanguageVersion = v } return nil } func (p *VersionInfo) ReadField3(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 3: ", err) } else { p.Version = v } return nil } func (p *VersionInfo) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("VersionInfo"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := 
p.writeField3(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *VersionInfo) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("language", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:language: ", p), err) } if err := oprot.WriteString(string(p.Language)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.language (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:language: ", p), err) } return err } func (p *VersionInfo) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("language_version", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:language_version: ", p), err) } if err := oprot.WriteString(string(p.LanguageVersion)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.language_version (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:language_version: ", p), err) } return err } func (p *VersionInfo) writeField3(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("version", thrift.STRING, 3); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 3:version: ", p), err) } if err := oprot.WriteString(string(p.Version)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.version (3) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 3:version: ", p), err) } return err } func (p *VersionInfo) String() string { if p == nil { 
return "" } return fmt.Sprintf("VersionInfo(%+v)", *p) } ================================================ FILE: thrift/gen-go/test/constants.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package test import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal func init() { } ================================================ FILE: thrift/gen-go/test/meta.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package test import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) 
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

// Meta is the service interface for the single Health call. NewMetaProcessor
// accepts a Meta as the server-side handler, and MetaClient exposes a Health
// method with the same signature.
type Meta interface {
	Health() (r *HealthStatus, err error)
}

// MetaClient issues Meta calls over Thrift protocols.
//
// InputProtocol and OutputProtocol may be nil; the send/recv helpers then
// create them from ProtocolFactory and Transport on first use and cache them.
// SeqId is incremented before each request and compared against the sequence
// id of the response.
type MetaClient struct {
	Transport       thrift.TTransport
	ProtocolFactory thrift.TProtocolFactory
	InputProtocol   thrift.TProtocol
	OutputProtocol  thrift.TProtocol
	SeqId           int32
}

// NewMetaClientFactory builds a client whose input and output protocols are
// each obtained from a separate f.GetProtocol(t) call.
func NewMetaClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *MetaClient {
	return &MetaClient{Transport: t,
		ProtocolFactory: f,
		InputProtocol:   f.GetProtocol(t),
		OutputProtocol:  f.GetProtocol(t),
		SeqId:           0,
	}
}

// NewMetaClientProtocol builds a client from caller-supplied protocols. The
// ProtocolFactory is left nil, so sendHealth/recvHealth would dereference a
// nil factory if iprot or oprot were nil.
func NewMetaClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *MetaClient {
	return &MetaClient{Transport: t,
		ProtocolFactory: nil,
		InputProtocol:   iprot,
		OutputProtocol:  oprot,
		SeqId:           0,
	}
}

// Health sends a "health" request and then reads its response. If sending
// fails, the send error is returned and no read is attempted.
func (p *MetaClient) Health() (r *HealthStatus, err error) {
	if err = p.sendHealth(); err != nil {
		return
	}
	return p.recvHealth()
}

// sendHealth increments SeqId, writes a CALL message named "health" carrying
// an empty MetaHealthArgs struct, and flushes the output protocol. The first
// error encountered is returned.
func (p *MetaClient) sendHealth() (err error) {
	oprot := p.OutputProtocol
	if oprot == nil {
		// Lazily create and cache the output protocol.
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("health", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := MetaHealthArgs{}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvHealth reads one response message and decodes it.
//
// It returns an application exception when the method name is not "health"
// (WRONG_METHOD_NAME), when the sequence id differs from p.SeqId
// (BAD_SEQUENCE_ID), or when the message type is neither EXCEPTION nor REPLY
// (INVALID_MESSAGE_TYPE_EXCEPTION). For an EXCEPTION message the decoded
// exception is returned as err. For a REPLY, value is result.GetSuccess().
func (p *MetaClient) recvHealth() (value *HealthStatus, err error) {
	iprot := p.InputProtocol
	if iprot == nil {
		// Lazily create and cache the input protocol.
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	// err here is the named result: := only declares the three new names.
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "health" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "health failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "health failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		error19 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error20 error
		error20, err = error19.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		// Surface the exception decoded from the wire as the call's error.
		err = error20
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "health failed: invalid message type")
		return
	}
	result := MetaHealthResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	value = result.GetSuccess()
	return
}

// MetaProcessor dispatches incoming messages to per-method processor
// functions looked up by message name in processorMap.
type MetaProcessor struct {
	processorMap map[string]thrift.TProcessorFunction
	handler      Meta
}

// AddToProcessorMap registers (or replaces) the processor function for key.
func (p *MetaProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
	p.processorMap[key] = processor
}

// GetProcessorFunction looks up the processor function registered for key;
// ok reports whether one was found.
func (p *MetaProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
	processor, ok = p.processorMap[key]
	return processor, ok
}

// ProcessorMap returns the underlying map itself, not a copy.
func (p *MetaProcessor) ProcessorMap() map[string]thrift.TProcessorFunction {
	return p.processorMap
}

// NewMetaProcessor creates a processor with "health" registered against
// handler.
func NewMetaProcessor(handler Meta) *MetaProcessor {

	self21 := &MetaProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)}
	self21.processorMap["health"] = &metaProcessorHealth{handler: handler}
	return self21
}

// Process reads a message header from iprot and delegates to the processor
// function registered under the message name.
//
// If the header cannot be read it returns (false, err). If no function is
// registered for the name, it skips the STRUCT payload, writes an
// UNKNOWN_METHOD exception message to oprot and returns (false, exception).
// Errors from the skip and from writing that exception are not checked.
func (p *MetaProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	name, _, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return false, err
	}
	if processor, ok := p.GetProcessorFunction(name); ok {
		return processor.Process(seqId, iprot, oprot)
	}
	iprot.Skip(thrift.STRUCT)
	iprot.ReadMessageEnd()
	x22 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)
	oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)
	x22.Write(oprot)
	oprot.WriteMessageEnd()
	oprot.Flush()
	return false, x22

}

// metaProcessorHealth is the processor function for the "health" method.
type metaProcessorHealth struct {
	handler Meta
}

// Process handles one "health" request.
//
//   - If the arguments cannot be read, a PROTOCOL_ERROR exception message is
//     written and (false, err) is returned.
//   - If the handler returns an error, an INTERNAL_ERROR exception message is
//     written and (true, err2) is returned.
//   - Otherwise a REPLY message carrying the result is written. All four
//     write steps run even after a failure; only the first error is kept. On
//     error the naked return yields (false, err), otherwise (true, nil).
//
// Errors from writing the exception messages are not checked.
func (p *metaProcessorHealth) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := MetaHealthArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("health", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := MetaHealthResult{}
	var retval *HealthStatus
	var err2 error
	if retval, err2 = p.handler.Health(); err2 != nil {
		x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing health: "+err2.Error())
		oprot.WriteMessageBegin("health", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return true, err2
	} else {
		result.Success = retval
	}
	if err2 = oprot.WriteMessageBegin("health", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	// From here on, err is only set if no earlier step already failed.
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		return
	}
	return true, err
}

// HELPER FUNCTIONS AND STRUCTURES

// MetaHealthArgs is the (field-less) argument struct of the "health" call.
type MetaHealthArgs struct {
}

// NewMetaHealthArgs returns a pointer to an empty MetaHealthArgs.
func NewMetaHealthArgs() *MetaHealthArgs {
	return &MetaHealthArgs{}
}

// Read consumes a struct from iprot, skipping every field until the STOP
// marker, since MetaHealthArgs declares no fields. Struct begin/end and
// field-begin errors are wrapped with the receiver type; skip and field-end
// errors are returned as-is.
func (p *MetaHealthArgs) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		if err := iprot.Skip(fieldTypeId); err != nil {
			return err
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// Write emits an empty struct named "health_args": struct begin, field stop,
// struct end.
func (p *MetaHealthArgs) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("health_args"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// String formats the struct as "MetaHealthArgs(<fields>)"; a nil receiver
// yields the empty string.
func (p *MetaHealthArgs) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("MetaHealthArgs(%+v)", *p)
}

// MetaHealthResult is the result struct of the "health" call.
//
// Attributes:
//  - Success
type MetaHealthResult struct {
	Success *HealthStatus `thrift:"success,0" db:"success" json:"success,omitempty"`
}

// NewMetaHealthResult returns a pointer to an empty MetaHealthResult.
func NewMetaHealthResult() *MetaHealthResult {
	return &MetaHealthResult{}
}

// MetaHealthResult_Success_DEFAULT is what GetSuccess returns when Success is
// unset. It is declared without an initializer.
var MetaHealthResult_Success_DEFAULT *HealthStatus

// GetSuccess returns Success, or MetaHealthResult_Success_DEFAULT when
// Success is nil.
func (p *MetaHealthResult) GetSuccess() *HealthStatus {
	if !p.IsSetSuccess() {
		return MetaHealthResult_Success_DEFAULT
	}
	return p.Success
}

// IsSetSuccess reports whether Success is non-nil.
func (p *MetaHealthResult) IsSetSuccess() bool {
	return p.Success != nil
}

// Read decodes a struct from iprot. Field id 0 is decoded into Success via
// ReadField0; every other field id is skipped. Reading stops at the STOP
// marker.
func (p *MetaHealthResult) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		switch fieldId {
		case 0:
			if err := p.ReadField0(iprot); err != nil {
				return err
			}
		default:
			if err := iprot.Skip(fieldTypeId); err != nil {
				return err
			}
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// ReadField0 replaces Success with a fresh HealthStatus and decodes it from
// iprot. Success stays assigned even if decoding fails.
func (p *MetaHealthResult) ReadField0(iprot thrift.TProtocol) error {
	p.Success = &HealthStatus{}
	if err := p.Success.Read(iprot); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Success), err)
	}
	return nil
}

// Write emits the struct as "health_result": struct begin, field 0 (only if
// set), field stop, struct end.
func (p *MetaHealthResult) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("health_result"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := p.writeField0(oprot); err != nil {
		return err
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// writeField0 writes Success as field 0 ("success", STRUCT) when it is set,
// and writes nothing otherwise.
func (p *MetaHealthResult) writeField0(oprot thrift.TProtocol) (err error) {
	if p.IsSetSuccess() {
		if err := oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err)
		}
		if err := p.Success.Write(oprot); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Success), err)
		}
		if err := oprot.WriteFieldEnd(); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err)
		}
	}
	// The named result is never assigned here, so this returns nil.
	return err
}

// String formats the struct as "MetaHealthResult(<fields>)"; a nil receiver
// yields the empty string.
func (p *MetaHealthResult) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("MetaHealthResult(%+v)", *p)
}

================================================
FILE: thrift/gen-go/test/secondservice.go
================================================
// Autogenerated by Thrift Compiler (1.0.0-dev)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING

package test

import (
	"bytes"
	"fmt"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
)

// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

// SecondService is the service interface for the single Echo call.
// NewSecondServiceProcessor accepts a SecondService as the server-side
// handler, and SecondServiceClient exposes an Echo method with the same
// signature.
type SecondService interface {
	// Parameters:
	//  - Arg
	Echo(arg string) (r string, err error)
}

// SecondServiceClient issues SecondService calls over Thrift protocols.
//
// InputProtocol and OutputProtocol may be nil; the send/recv helpers then
// create them from ProtocolFactory and Transport on first use and cache them.
// SeqId is incremented before each request and compared against the sequence
// id of the response.
type SecondServiceClient struct {
	Transport       thrift.TTransport
	ProtocolFactory thrift.TProtocolFactory
	InputProtocol   thrift.TProtocol
	OutputProtocol  thrift.TProtocol
	SeqId           int32
}

// NewSecondServiceClientFactory builds a client whose input and output
// protocols are each obtained from a separate f.GetProtocol(t) call.
func NewSecondServiceClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *SecondServiceClient {
	return &SecondServiceClient{Transport: t,
		ProtocolFactory: f,
		InputProtocol:   f.GetProtocol(t),
		OutputProtocol:  f.GetProtocol(t),
		SeqId:           0,
	}
}

// NewSecondServiceClientProtocol builds a client from caller-supplied
// protocols. The ProtocolFactory is left nil, so sendEcho/recvEcho would
// dereference a nil factory if iprot or oprot were nil.
func NewSecondServiceClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *SecondServiceClient {
	return &SecondServiceClient{Transport: t,
		ProtocolFactory: nil,
		InputProtocol:   iprot,
		OutputProtocol:  oprot,
		SeqId:           0,
	}
}

// Echo sends an "Echo" request carrying arg and then reads its response. If
// sending fails, the send error is returned and no read is attempted.
//
// Parameters:
//  - Arg
func (p *SecondServiceClient) Echo(arg string) (r string, err error) {
	if err = p.sendEcho(arg); err != nil {
		return
	}
	return p.recvEcho()
}

// sendEcho increments SeqId, writes a CALL message named "Echo" carrying a
// SecondServiceEchoArgs with Arg set to arg, and flushes the output protocol.
// The first error encountered is returned.
func (p *SecondServiceClient) sendEcho(arg string) (err error) {
	oprot := p.OutputProtocol
	if oprot == nil {
		// Lazily create and cache the output protocol.
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("Echo", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := SecondServiceEchoArgs{
		Arg: arg,
	}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvEcho reads one response message and decodes it.
//
// It returns an application exception when the method name is not "Echo"
// (WRONG_METHOD_NAME), when the sequence id differs from p.SeqId
// (BAD_SEQUENCE_ID), or when the message type is neither EXCEPTION nor REPLY
// (INVALID_MESSAGE_TYPE_EXCEPTION). For an EXCEPTION message the decoded
// exception is returned as err. For a REPLY, value is result.GetSuccess().
func (p *SecondServiceClient) recvEcho() (value string, err error) {
	iprot := p.InputProtocol
	if iprot == nil {
		// Lazily create and cache the input protocol.
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	// err here is the named result: := only declares the three new names.
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "Echo" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Echo failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Echo failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		error14 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error15 error
		error15, err = error14.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		// Surface the exception decoded from the wire as the call's error.
		err = error15
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Echo failed: invalid message type")
		return
	}
	result := SecondServiceEchoResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	value = result.GetSuccess()
	return
}

// SecondServiceProcessor dispatches incoming messages to per-method processor
// functions looked up by message name in processorMap.
type SecondServiceProcessor struct {
	processorMap map[string]thrift.TProcessorFunction
	handler      SecondService
}

// AddToProcessorMap registers (or replaces) the processor function for key.
func (p *SecondServiceProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
	p.processorMap[key] = processor
}

// GetProcessorFunction looks up the processor function registered for key;
// ok reports whether one was found.
func (p *SecondServiceProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
	processor, ok = p.processorMap[key]
	return processor, ok
}

// ProcessorMap returns the underlying map itself, not a copy.
func (p *SecondServiceProcessor) ProcessorMap() map[string]thrift.TProcessorFunction {
	return p.processorMap
}

// NewSecondServiceProcessor creates a processor with "Echo" registered
// against handler.
func NewSecondServiceProcessor(handler SecondService) *SecondServiceProcessor {

	self16 := &SecondServiceProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)}
	self16.processorMap["Echo"] = &secondServiceProcessorEcho{handler: handler}
	return self16
}

// Process reads a message header from iprot and delegates to the processor
// function registered under the message name.
//
// If the header cannot be read it returns (false, err). If no function is
// registered for the name, it skips the STRUCT payload, writes an
// UNKNOWN_METHOD exception message to oprot and returns (false, exception).
// Errors from the skip and from writing that exception are not checked.
func (p *SecondServiceProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	name, _, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return false, err
	}
	if processor, ok := p.GetProcessorFunction(name); ok {
		return processor.Process(seqId, iprot, oprot)
	}
	iprot.Skip(thrift.STRUCT)
	iprot.ReadMessageEnd()
	x17 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)
	oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)
	x17.Write(oprot)
	oprot.WriteMessageEnd()
	oprot.Flush()
	return false, x17

}

type
secondServiceProcessorEcho struct {
	// handler is the SecondService implementation that Process invokes.
	handler SecondService
}

// Process handles one "Echo" request.
//
//   - If the arguments cannot be read, a PROTOCOL_ERROR exception message is
//     written and (false, err) is returned.
//   - If the handler returns an error, an INTERNAL_ERROR exception message is
//     written and (true, err2) is returned.
//   - Otherwise a REPLY message carrying the result is written. All four
//     write steps run even after a failure; only the first error is kept. On
//     error the naked return yields (false, err), otherwise (true, nil).
//
// Errors from writing the exception messages are not checked.
func (p *secondServiceProcessorEcho) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := SecondServiceEchoArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("Echo", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := SecondServiceEchoResult{}
	var retval string
	var err2 error
	if retval, err2 = p.handler.Echo(args.Arg); err2 != nil {
		x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Echo: "+err2.Error())
		oprot.WriteMessageBegin("Echo", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return true, err2
	} else {
		// Success is a *string, so the reply always has it set on this path,
		// even when retval is "".
		result.Success = &retval
	}
	if err2 = oprot.WriteMessageBegin("Echo", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	// From here on, err is only set if no earlier step already failed.
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		return
	}
	return true, err
}

// HELPER FUNCTIONS AND STRUCTURES

// SecondServiceEchoArgs is the argument struct of the "Echo" call.
//
// Attributes:
//  - Arg
type SecondServiceEchoArgs struct {
	Arg string `thrift:"arg,1" db:"arg" json:"arg"`
}

// NewSecondServiceEchoArgs returns a pointer to an empty
// SecondServiceEchoArgs.
func NewSecondServiceEchoArgs() *SecondServiceEchoArgs {
	return &SecondServiceEchoArgs{}
}

// GetArg returns the Arg field.
func (p *SecondServiceEchoArgs) GetArg() string {
	return p.Arg
}

// Read decodes a struct from iprot. Field id 1 is decoded into Arg via
// ReadField1; every other field id is skipped. Reading stops at the STOP
// marker.
func (p *SecondServiceEchoArgs) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		switch fieldId {
		case 1:
			if err := p.ReadField1(iprot); err != nil {
				return err
			}
		default:
			if err := iprot.Skip(fieldTypeId); err != nil {
				return err
			}
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// ReadField1 reads a string from iprot into Arg. Arg is left untouched if the
// read fails.
func (p *SecondServiceEchoArgs) ReadField1(iprot thrift.TProtocol) error {
	if v, err := iprot.ReadString(); err != nil {
		return thrift.PrependError("error reading field 1: ", err)
	} else {
		p.Arg = v
	}
	return nil
}

// Write emits the struct as "Echo_args": struct begin, field 1, field stop,
// struct end.
func (p *SecondServiceEchoArgs) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("Echo_args"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := p.writeField1(oprot); err != nil {
		return err
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// writeField1 writes Arg as field 1 ("arg", STRING). The field is always
// written; there is no is-set check.
func (p *SecondServiceEchoArgs) writeField1(oprot thrift.TProtocol) (err error) {
	if err := oprot.WriteFieldBegin("arg", thrift.STRING, 1); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:arg: ", p), err)
	}
	if err := oprot.WriteString(string(p.Arg)); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T.arg (1) field write error: ", p), err)
	}
	if err := oprot.WriteFieldEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write field end error 1:arg: ", p), err)
	}
	// The named result is never assigned here, so this returns nil.
	return err
}

// String formats the struct as "SecondServiceEchoArgs(<fields>)"; a nil
// receiver yields the empty string.
func (p *SecondServiceEchoArgs) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("SecondServiceEchoArgs(%+v)", *p)
}

// SecondServiceEchoResult is the result struct of the "Echo" call.
//
// Attributes:
//  - Success
type SecondServiceEchoResult struct {
	Success *string `thrift:"success,0" db:"success" json:"success,omitempty"`
}

// NewSecondServiceEchoResult returns a pointer to an empty
// SecondServiceEchoResult.
func NewSecondServiceEchoResult() *SecondServiceEchoResult {
	return &SecondServiceEchoResult{}
}

// SecondServiceEchoResult_Success_DEFAULT is what GetSuccess returns when
// Success is unset. It is declared without an initializer.
var SecondServiceEchoResult_Success_DEFAULT string

// GetSuccess returns the string Success points to, or
// SecondServiceEchoResult_Success_DEFAULT when Success is nil.
func (p *SecondServiceEchoResult) GetSuccess() string {
	if !p.IsSetSuccess() {
		return SecondServiceEchoResult_Success_DEFAULT
	}
	return *p.Success
}

// IsSetSuccess reports whether Success is non-nil.
func (p *SecondServiceEchoResult) IsSetSuccess() bool {
	return p.Success != nil
}

// Read decodes a struct from iprot. Field id 0 is decoded into Success via
// ReadField0; every other field id is skipped. Reading stops at the STOP
// marker.
func (p *SecondServiceEchoResult) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		switch fieldId {
		case 0:
			if err := p.ReadField0(iprot); err != nil {
				return err
			}
		default:
			if err := iprot.Skip(fieldTypeId); err != nil {
				return err
			}
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// ReadField0 reads a string from iprot and points Success at it. Success is
// left untouched if the read fails.
func (p *SecondServiceEchoResult) ReadField0(iprot thrift.TProtocol) error {
	if v, err := iprot.ReadString(); err != nil {
		return thrift.PrependError("error reading field 0: ", err)
	} else {
		p.Success = &v
	}
	return nil
}

// Write emits the struct as "Echo_result": struct begin, field 0 (only if
// set), field stop, struct end.
func (p *SecondServiceEchoResult) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("Echo_result"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := p.writeField0(oprot); err != nil {
		return err
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// writeField0 writes *Success as field 0 ("success", STRING) when it is set,
// and writes nothing otherwise.
func (p *SecondServiceEchoResult) writeField0(oprot thrift.TProtocol) (err error) {
	if p.IsSetSuccess() {
		if err := oprot.WriteFieldBegin("success", thrift.STRING, 0); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err)
		}
		if err := oprot.WriteString(string(*p.Success)); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T.success (0) field write error: ", p), err)
		}
		if err := oprot.WriteFieldEnd(); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err)
		}
	}
	// The named result is never assigned here, so this returns nil.
	return err
}

// String formats the struct as "SecondServiceEchoResult(<fields>)"; a nil
// receiver yields the empty string.
func (p *SecondServiceEchoResult) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("SecondServiceEchoResult(%+v)", *p)
}

================================================
FILE: thrift/gen-go/test/simpleservice.go
================================================
// Autogenerated by Thrift Compiler (1.0.0-dev)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING

package test

import (
	"bytes"
	"fmt"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
)

// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal

// SimpleService is the service interface with three calls: Call, Simple and
// SimpleFuture. NewSimpleServiceClientFactory and
// NewSimpleServiceClientProtocol build the client side.
type SimpleService interface {
	// Parameters:
	//  - Arg
	Call(arg *Data) (r *Data, err error)
	Simple() (err error)
	SimpleFuture() (err error)
}

// SimpleServiceClient issues SimpleService calls over Thrift protocols. It
// carries the transport, an optional protocol factory, the input and output
// protocols, and the sequence id of the most recent request.
type SimpleServiceClient struct {
	Transport       thrift.TTransport
	ProtocolFactory thrift.TProtocolFactory
	InputProtocol   thrift.TProtocol
	OutputProtocol  thrift.TProtocol
	SeqId           int32
}

// NewSimpleServiceClientFactory builds a client whose input and output
// protocols are each obtained from a separate f.GetProtocol(t) call.
func NewSimpleServiceClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *SimpleServiceClient {
	return &SimpleServiceClient{Transport: t,
		ProtocolFactory: f,
		InputProtocol:   f.GetProtocol(t),
		OutputProtocol:  f.GetProtocol(t),
		SeqId:           0,
	}
}

// NewSimpleServiceClientProtocol builds a client from caller-supplied
// protocols; the ProtocolFactory is left nil.
func NewSimpleServiceClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *SimpleServiceClient {
	return &SimpleServiceClient{Transport: t,
		ProtocolFactory: nil,
		InputProtocol:   iprot,
		OutputProtocol:  oprot,
		SeqId:           0,
	}
}

// Call sends a "Call" request carrying arg and then reads its response. If
// sending fails, the send error is returned and no read is attempted.
//
// Parameters:
//  - Arg
func (p *SimpleServiceClient) Call(arg *Data) (r *Data, err error) {
	if err = p.sendCall(arg); err != nil {
		return
	}
	return p.recvCall()
}

func (p *SimpleServiceClient) sendCall(arg *Data) (err error) {
	oprot :=
p.OutputProtocol
	if oprot == nil {
		// Lazily create and cache the output protocol.
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("Call", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := SimpleServiceCallArgs{
		Arg: arg,
	}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvCall reads one response message and decodes it.
//
// It returns an application exception when the method name is not "Call"
// (WRONG_METHOD_NAME), when the sequence id differs from p.SeqId
// (BAD_SEQUENCE_ID), or when the message type is neither EXCEPTION nor REPLY
// (INVALID_MESSAGE_TYPE_EXCEPTION). For an EXCEPTION message the decoded
// exception is returned as err. For a REPLY, value is result.GetSuccess().
func (p *SimpleServiceClient) recvCall() (value *Data, err error) {
	iprot := p.InputProtocol
	if iprot == nil {
		// Lazily create and cache the input protocol.
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	// err here is the named result: := only declares the three new names.
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "Call" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Call failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Call failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		error0 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error1 error
		error1, err = error0.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		// Surface the exception decoded from the wire as the call's error.
		err = error1
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Call failed: invalid message type")
		return
	}
	result := SimpleServiceCallResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	value = result.GetSuccess()
	return
}

// Simple sends a "Simple" request and then reads its response. If sending
// fails, the send error is returned and no read is attempted.
func (p *SimpleServiceClient) Simple() (err error) {
	if err = p.sendSimple(); err != nil {
		return
	}
	return p.recvSimple()
}

// sendSimple increments SeqId, writes a CALL message named "Simple" carrying
// an empty SimpleServiceSimpleArgs struct, and flushes the output protocol.
// The first error encountered is returned.
func (p *SimpleServiceClient) sendSimple() (err error) {
	oprot := p.OutputProtocol
	if oprot == nil {
		// Lazily create and cache the output protocol.
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("Simple", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := SimpleServiceSimpleArgs{}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvSimple reads one response message and decodes it.
//
// Header validation matches the other recv helpers: wrong method name,
// sequence-id mismatch and unexpected message type each yield an application
// exception, and an EXCEPTION message is returned as err. For a REPLY, a
// non-nil result.SimpleErr is returned as err; otherwise err is nil.
func (p *SimpleServiceClient) recvSimple() (err error) {
	iprot := p.InputProtocol
	if iprot == nil {
		// Lazily create and cache the input protocol.
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	// err here is the named result: := only declares the three new names.
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "Simple" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "Simple failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "Simple failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		error2 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error3 error
		error3, err = error2.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		// Surface the exception decoded from the wire as the call's error.
		err = error3
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "Simple failed: invalid message type")
		return
	}
	result := SimpleServiceSimpleResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	// A declared service exception carried in the result becomes the error.
	if result.SimpleErr != nil {
		err = result.SimpleErr
		return
	}
	return
}

// SimpleFuture sends a "SimpleFuture" request and then reads its response. If
// sending fails, the send error is returned and no read is attempted.
func (p *SimpleServiceClient) SimpleFuture() (err error) {
	if err = p.sendSimpleFuture(); err != nil {
		return
	}
	return p.recvSimpleFuture()
}

// sendSimpleFuture increments SeqId, writes a CALL message named
// "SimpleFuture" carrying an empty SimpleServiceSimpleFutureArgs struct, and
// flushes the output protocol. The first error encountered is returned.
func (p *SimpleServiceClient) sendSimpleFuture() (err error) {
	oprot := p.OutputProtocol
	if oprot == nil {
		// Lazily create and cache the output protocol.
		oprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.OutputProtocol = oprot
	}
	p.SeqId++
	if err = oprot.WriteMessageBegin("SimpleFuture", thrift.CALL, p.SeqId); err != nil {
		return
	}
	args := SimpleServiceSimpleFutureArgs{}
	if err = args.Write(oprot); err != nil {
		return
	}
	if err = oprot.WriteMessageEnd(); err != nil {
		return
	}
	return oprot.Flush()
}

// recvSimpleFuture reads one response message and decodes it.
//
// Header validation matches the other recv helpers. For a REPLY, a non-nil
// result.SimpleErr is returned as err; failing that, a non-nil
// result.NewErr_ is returned; otherwise err is nil. SimpleErr takes
// precedence if both are set.
func (p *SimpleServiceClient) recvSimpleFuture() (err error) {
	iprot := p.InputProtocol
	if iprot == nil {
		// Lazily create and cache the input protocol.
		iprot = p.ProtocolFactory.GetProtocol(p.Transport)
		p.InputProtocol = iprot
	}
	// err here is the named result: := only declares the three new names.
	method, mTypeId, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return
	}
	if method != "SimpleFuture" {
		err = thrift.NewTApplicationException(thrift.WRONG_METHOD_NAME, "SimpleFuture failed: wrong method name")
		return
	}
	if p.SeqId != seqId {
		err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, "SimpleFuture failed: out of sequence response")
		return
	}
	if mTypeId == thrift.EXCEPTION {
		error4 := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "Unknown Exception")
		var error5 error
		error5, err = error4.Read(iprot)
		if err != nil {
			return
		}
		if err = iprot.ReadMessageEnd(); err != nil {
			return
		}
		// Surface the exception decoded from the wire as the call's error.
		err = error5
		return
	}
	if mTypeId != thrift.REPLY {
		err = thrift.NewTApplicationException(thrift.INVALID_MESSAGE_TYPE_EXCEPTION, "SimpleFuture failed: invalid message type")
		return
	}
	result := SimpleServiceSimpleFutureResult{}
	if err = result.Read(iprot); err != nil {
		return
	}
	if err = iprot.ReadMessageEnd(); err != nil {
		return
	}
	if result.SimpleErr != nil {
		err = result.SimpleErr
		return
	} else if result.NewErr_ != nil {
		err = result.NewErr_
		return
	}
	return
}

// SimpleServiceProcessor dispatches incoming messages to per-method processor
// functions looked up by message name in processorMap.
type SimpleServiceProcessor struct {
	processorMap map[string]thrift.TProcessorFunction
	handler      SimpleService
}

// AddToProcessorMap registers (or replaces) the processor function for key.
func (p *SimpleServiceProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
	p.processorMap[key] = processor
}

// GetProcessorFunction looks up the processor function registered for key;
// ok reports whether one was found.
func (p *SimpleServiceProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
	processor, ok = p.processorMap[key]
	return processor, ok
}

// ProcessorMap returns the underlying map itself, not a copy.
func (p *SimpleServiceProcessor) ProcessorMap() map[string]thrift.TProcessorFunction {
	return p.processorMap
}

// NewSimpleServiceProcessor creates a processor with "Call", "Simple" and
// "SimpleFuture" registered against handler.
func NewSimpleServiceProcessor(handler SimpleService) *SimpleServiceProcessor {

	self6 := &SimpleServiceProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)}
	self6.processorMap["Call"] = &simpleServiceProcessorCall{handler: handler}
	self6.processorMap["Simple"] = &simpleServiceProcessorSimple{handler: handler}
	self6.processorMap["SimpleFuture"] = &simpleServiceProcessorSimpleFuture{handler: handler}
	return self6
}

// Process reads a message header from iprot and delegates to the processor
// function registered under the message name.
//
// If the header cannot be read it returns (false, err). If no function is
// registered for the name, it skips the STRUCT payload, writes an
// UNKNOWN_METHOD exception message to oprot and returns (false, exception).
// Errors from the skip and from writing that exception are not checked.
func (p *SimpleServiceProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	name, _, seqId, err := iprot.ReadMessageBegin()
	if err != nil {
		return false, err
	}
	if processor, ok := p.GetProcessorFunction(name); ok {
		return processor.Process(seqId, iprot, oprot)
	}
	iprot.Skip(thrift.STRUCT)
	iprot.ReadMessageEnd()
	x7 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)
	oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)
	x7.Write(oprot)
	oprot.WriteMessageEnd()
	oprot.Flush()
	return false, x7

}

// simpleServiceProcessorCall is the processor function for the "Call" method.
type simpleServiceProcessorCall struct {
	handler SimpleService
}

// Process handles one "Call" request.
//
//   - If the arguments cannot be read, a PROTOCOL_ERROR exception message is
//     written and (false, err) is returned.
//   - If the handler returns an error, an INTERNAL_ERROR exception message is
//     written and (true, err2) is returned.
//   - Otherwise a REPLY message carrying the result is written. All four
//     write steps run even after a failure; only the first error is kept. On
//     error the naked return yields (false, err), otherwise (true, nil).
//
// Errors from writing the exception messages are not checked.
func (p *simpleServiceProcessorCall) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := SimpleServiceCallArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("Call", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := SimpleServiceCallResult{}
	var retval *Data
	var err2 error
	if retval, err2 = p.handler.Call(args.Arg); err2 != nil {
		x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Call: "+err2.Error())
		oprot.WriteMessageBegin("Call", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return true, err2
	} else {
		result.Success = retval
	}
	if err2 = oprot.WriteMessageBegin("Call", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	// From here on, err is only set if no earlier step already failed.
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		return
	}
	return true, err
}

type
simpleServiceProcessorSimple struct {
	// handler is the SimpleService implementation that Process invokes.
	handler SimpleService
}

// Process handles one "Simple" request.
//
//   - If the arguments cannot be read, a PROTOCOL_ERROR exception message is
//     written and (false, err) is returned.
//   - If the handler returns a *SimpleErr, it is stored in the result and sent
//     back inside a normal REPLY.
//   - If the handler returns any other error, an INTERNAL_ERROR exception
//     message is written and (true, err2) is returned.
//   - Otherwise a REPLY message carrying the result is written. All four
//     write steps run even after a failure; only the first error is kept. On
//     error the naked return yields (false, err), otherwise (true, nil).
//
// Errors from writing the exception messages are not checked.
func (p *simpleServiceProcessorSimple) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := SimpleServiceSimpleArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("Simple", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := SimpleServiceSimpleResult{}
	var err2 error
	if err2 = p.handler.Simple(); err2 != nil {
		switch v := err2.(type) {
		case *SimpleErr:
			// Declared exception: travels in the result struct.
			result.SimpleErr = v
		default:
			x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing Simple: "+err2.Error())
			oprot.WriteMessageBegin("Simple", thrift.EXCEPTION, seqId)
			x.Write(oprot)
			oprot.WriteMessageEnd()
			oprot.Flush()
			return true, err2
		}
	}
	if err2 = oprot.WriteMessageBegin("Simple", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	// From here on, err is only set if no earlier step already failed.
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		return
	}
	return true, err
}

// simpleServiceProcessorSimpleFuture is the processor function for the
// "SimpleFuture" method.
type simpleServiceProcessorSimpleFuture struct {
	handler SimpleService
}

// Process handles one "SimpleFuture" request.
//
// It behaves like the "Simple" processor, except that two handler error types
// are carried inside the REPLY result: *SimpleErr (stored in result.SimpleErr)
// and *NewErr_ (stored in result.NewErr_). Any other handler error produces an
// INTERNAL_ERROR exception message and a (true, err2) return.
func (p *simpleServiceProcessorSimpleFuture) Process(seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
	args := SimpleServiceSimpleFutureArgs{}
	if err = args.Read(iprot); err != nil {
		iprot.ReadMessageEnd()
		x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())
		oprot.WriteMessageBegin("SimpleFuture", thrift.EXCEPTION, seqId)
		x.Write(oprot)
		oprot.WriteMessageEnd()
		oprot.Flush()
		return false, err
	}

	iprot.ReadMessageEnd()
	result := SimpleServiceSimpleFutureResult{}
	var err2 error
	if err2 = p.handler.SimpleFuture(); err2 != nil {
		switch v := err2.(type) {
		case *SimpleErr:
			result.SimpleErr = v
		case *NewErr_:
			result.NewErr_ = v
		default:
			x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "Internal error processing SimpleFuture: "+err2.Error())
			oprot.WriteMessageBegin("SimpleFuture", thrift.EXCEPTION, seqId)
			x.Write(oprot)
			oprot.WriteMessageEnd()
			oprot.Flush()
			return true, err2
		}
	}
	if err2 = oprot.WriteMessageBegin("SimpleFuture", thrift.REPLY, seqId); err2 != nil {
		err = err2
	}
	// From here on, err is only set if no earlier step already failed.
	if err2 = result.Write(oprot); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {
		err = err2
	}
	if err2 = oprot.Flush(); err == nil && err2 != nil {
		err = err2
	}
	if err != nil {
		return
	}
	return true, err
}

// HELPER FUNCTIONS AND STRUCTURES

// SimpleServiceCallArgs is the argument struct of the "Call" call.
//
// Attributes:
//  - Arg
type SimpleServiceCallArgs struct {
	Arg *Data `thrift:"arg,1" db:"arg" json:"arg"`
}

// NewSimpleServiceCallArgs returns a pointer to an empty
// SimpleServiceCallArgs.
func NewSimpleServiceCallArgs() *SimpleServiceCallArgs {
	return &SimpleServiceCallArgs{}
}

// SimpleServiceCallArgs_Arg_DEFAULT is what GetArg returns when Arg is unset.
// It is declared without an initializer.
var SimpleServiceCallArgs_Arg_DEFAULT *Data

// GetArg returns Arg, or SimpleServiceCallArgs_Arg_DEFAULT when Arg is nil.
func (p *SimpleServiceCallArgs) GetArg() *Data {
	if !p.IsSetArg() {
		return SimpleServiceCallArgs_Arg_DEFAULT
	}
	return p.Arg
}

// IsSetArg reports whether Arg is non-nil.
func (p *SimpleServiceCallArgs) IsSetArg() bool {
	return p.Arg != nil
}

// Read decodes a struct from iprot. Field id 1 is decoded into Arg via
// ReadField1; every other field id is skipped. Reading stops at the STOP
// marker.
func (p *SimpleServiceCallArgs) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		switch fieldId {
		case 1:
			if err := p.ReadField1(iprot); err != nil {
				return err
			}
		default:
			if err := iprot.Skip(fieldTypeId); err != nil {
				return err
			}
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// ReadField1 replaces Arg with a fresh Data and decodes it from iprot. Arg
// stays assigned even if decoding fails.
func (p *SimpleServiceCallArgs) ReadField1(iprot thrift.TProtocol) error {
	p.Arg = &Data{}
	if err := p.Arg.Read(iprot); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Arg), err)
	}
	return nil
}

// Write emits the struct as "Call_args": struct begin, field 1, field stop,
// struct end.
func (p *SimpleServiceCallArgs) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("Call_args"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := p.writeField1(oprot); err != nil {
		return err
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// writeField1 writes Arg as field 1 ("arg", STRUCT). Unlike the result's
// writeField0, there is no is-set guard: p.Arg.Write is called
// unconditionally.
// NOTE(review): behaviour with a nil Arg depends on Data.Write, which is not
// visible here — confirm callers never send a nil Arg.
func (p *SimpleServiceCallArgs) writeField1(oprot thrift.TProtocol) (err error) {
	if err := oprot.WriteFieldBegin("arg", thrift.STRUCT, 1); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:arg: ", p), err)
	}
	if err := p.Arg.Write(oprot); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Arg), err)
	}
	if err := oprot.WriteFieldEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write field end error 1:arg: ", p), err)
	}
	// The named result is never assigned here, so this returns nil.
	return err
}

// String formats the struct as "SimpleServiceCallArgs(<fields>)"; a nil
// receiver yields the empty string.
func (p *SimpleServiceCallArgs) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("SimpleServiceCallArgs(%+v)", *p)
}

// SimpleServiceCallResult is the result struct of the "Call" call.
//
// Attributes:
//  - Success
type SimpleServiceCallResult struct {
	Success *Data `thrift:"success,0" db:"success" json:"success,omitempty"`
}

// NewSimpleServiceCallResult returns a pointer to an empty
// SimpleServiceCallResult.
func NewSimpleServiceCallResult() *SimpleServiceCallResult {
	return &SimpleServiceCallResult{}
}

// SimpleServiceCallResult_Success_DEFAULT is what GetSuccess returns when
// Success is unset. It is declared without an initializer.
var SimpleServiceCallResult_Success_DEFAULT *Data

// GetSuccess returns Success, or SimpleServiceCallResult_Success_DEFAULT when
// Success is nil.
func (p *SimpleServiceCallResult) GetSuccess() *Data {
	if !p.IsSetSuccess() {
		return SimpleServiceCallResult_Success_DEFAULT
	}
	return p.Success
}

// IsSetSuccess reports whether Success is non-nil.
func (p *SimpleServiceCallResult) IsSetSuccess() bool {
	return p.Success != nil
}

// Read decodes a struct from iprot. Field id 0 is decoded into Success via
// ReadField0; every other field id is skipped. Reading stops at the STOP
// marker.
func (p *SimpleServiceCallResult) Read(iprot thrift.TProtocol) error {
	if _, err := iprot.ReadStructBegin(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
	}

	for {
		_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
		if err != nil {
			return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
		}
		if fieldTypeId == thrift.STOP {
			break
		}
		switch fieldId {
		case 0:
			if err := p.ReadField0(iprot); err != nil {
				return err
			}
		default:
			if err := iprot.Skip(fieldTypeId); err != nil {
				return err
			}
		}
		if err := iprot.ReadFieldEnd(); err != nil {
			return err
		}
	}
	if err := iprot.ReadStructEnd(); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
	}
	return nil
}

// ReadField0 replaces Success with a fresh Data and decodes it from iprot.
// Success stays assigned even if decoding fails.
func (p *SimpleServiceCallResult) ReadField0(iprot thrift.TProtocol) error {
	p.Success = &Data{}
	if err := p.Success.Read(iprot); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Success), err)
	}
	return nil
}

// Write emits the struct as "Call_result": struct begin, field 0 (only if
// set), field stop, struct end.
func (p *SimpleServiceCallResult) Write(oprot thrift.TProtocol) error {
	if err := oprot.WriteStructBegin("Call_result"); err != nil {
		return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
	}
	if err := p.writeField0(oprot); err != nil {
		return err
	}
	if err := oprot.WriteFieldStop(); err != nil {
		return thrift.PrependError("write field stop error: ", err)
	}
	if err := oprot.WriteStructEnd(); err != nil {
		return thrift.PrependError("write struct stop error: ", err)
	}
	return nil
}

// writeField0 writes Success as field 0 ("success", STRUCT) when it is set,
// and writes nothing otherwise.
func (p *SimpleServiceCallResult) writeField0(oprot thrift.TProtocol) (err error) {
	if p.IsSetSuccess() {
		if err := oprot.WriteFieldBegin("success", thrift.STRUCT, 0); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field begin error 0:success: ", p), err)
		}
		if err := p.Success.Write(oprot); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Success), err)
		}
		if err := oprot.WriteFieldEnd(); err != nil {
			return thrift.PrependError(fmt.Sprintf("%T write field end error 0:success: ", p), err)
		}
	}
	// The named result is never assigned here, so this returns nil.
	return err
}

// String formats the struct as "SimpleServiceCallResult(<fields>)"; a nil
// receiver yields the empty string.
func (p *SimpleServiceCallResult) String() string {
	if p == nil {
		return ""
	}
	return fmt.Sprintf("SimpleServiceCallResult(%+v)", *p)
}

type
SimpleServiceSimpleArgs struct { } func NewSimpleServiceSimpleArgs() *SimpleServiceSimpleArgs { return &SimpleServiceSimpleArgs{} } func (p *SimpleServiceSimpleArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *SimpleServiceSimpleArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Simple_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *SimpleServiceSimpleArgs) String() string { if p == nil { return "" } return fmt.Sprintf("SimpleServiceSimpleArgs(%+v)", *p) } // Attributes: // - SimpleErr type SimpleServiceSimpleResult struct { SimpleErr *SimpleErr `thrift:"simpleErr,1" db:"simpleErr" json:"simpleErr,omitempty"` } func NewSimpleServiceSimpleResult() *SimpleServiceSimpleResult { return &SimpleServiceSimpleResult{} } var SimpleServiceSimpleResult_SimpleErr_DEFAULT *SimpleErr func (p *SimpleServiceSimpleResult) GetSimpleErr() *SimpleErr { if !p.IsSetSimpleErr() { return SimpleServiceSimpleResult_SimpleErr_DEFAULT } return p.SimpleErr } func (p *SimpleServiceSimpleResult) IsSetSimpleErr() bool { return p.SimpleErr != nil } func (p *SimpleServiceSimpleResult) Read(iprot 
thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *SimpleServiceSimpleResult) ReadField1(iprot thrift.TProtocol) error { p.SimpleErr = &SimpleErr{} if err := p.SimpleErr.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.SimpleErr), err) } return nil } func (p *SimpleServiceSimpleResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Simple_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *SimpleServiceSimpleResult) writeField1(oprot thrift.TProtocol) (err error) { if p.IsSetSimpleErr() { if err := oprot.WriteFieldBegin("simpleErr", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:simpleErr: ", p), err) } if err := p.SimpleErr.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.SimpleErr), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 
1:simpleErr: ", p), err) } } return err } func (p *SimpleServiceSimpleResult) String() string { if p == nil { return "" } return fmt.Sprintf("SimpleServiceSimpleResult(%+v)", *p) } type SimpleServiceSimpleFutureArgs struct { } func NewSimpleServiceSimpleFutureArgs() *SimpleServiceSimpleFutureArgs { return &SimpleServiceSimpleFutureArgs{} } func (p *SimpleServiceSimpleFutureArgs) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } if err := iprot.Skip(fieldTypeId); err != nil { return err } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *SimpleServiceSimpleFutureArgs) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("SimpleFuture_args"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *SimpleServiceSimpleFutureArgs) String() string { if p == nil { return "" } return fmt.Sprintf("SimpleServiceSimpleFutureArgs(%+v)", *p) } // Attributes: // - SimpleErr // - NewErr_ type SimpleServiceSimpleFutureResult struct { SimpleErr *SimpleErr `thrift:"simpleErr,1" db:"simpleErr" json:"simpleErr,omitempty"` NewErr_ *NewErr_ `thrift:"newErr,2" db:"newErr" json:"newErr,omitempty"` } func NewSimpleServiceSimpleFutureResult() *SimpleServiceSimpleFutureResult { return &SimpleServiceSimpleFutureResult{} } var 
SimpleServiceSimpleFutureResult_SimpleErr_DEFAULT *SimpleErr func (p *SimpleServiceSimpleFutureResult) GetSimpleErr() *SimpleErr { if !p.IsSetSimpleErr() { return SimpleServiceSimpleFutureResult_SimpleErr_DEFAULT } return p.SimpleErr } var SimpleServiceSimpleFutureResult_NewErr__DEFAULT *NewErr_ func (p *SimpleServiceSimpleFutureResult) GetNewErr_() *NewErr_ { if !p.IsSetNewErr_() { return SimpleServiceSimpleFutureResult_NewErr__DEFAULT } return p.NewErr_ } func (p *SimpleServiceSimpleFutureResult) IsSetSimpleErr() bool { return p.SimpleErr != nil } func (p *SimpleServiceSimpleFutureResult) IsSetNewErr_() bool { return p.NewErr_ != nil } func (p *SimpleServiceSimpleFutureResult) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } case 2: if err := p.ReadField2(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *SimpleServiceSimpleFutureResult) ReadField1(iprot thrift.TProtocol) error { p.SimpleErr = &SimpleErr{} if err := p.SimpleErr.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.SimpleErr), err) } return nil } func (p *SimpleServiceSimpleFutureResult) ReadField2(iprot thrift.TProtocol) error { p.NewErr_ = &NewErr_{} if err := p.NewErr_.Read(iprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.NewErr_), err) } return nil } func (p 
*SimpleServiceSimpleFutureResult) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("SimpleFuture_result"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *SimpleServiceSimpleFutureResult) writeField1(oprot thrift.TProtocol) (err error) { if p.IsSetSimpleErr() { if err := oprot.WriteFieldBegin("simpleErr", thrift.STRUCT, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:simpleErr: ", p), err) } if err := p.SimpleErr.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.SimpleErr), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:simpleErr: ", p), err) } } return err } func (p *SimpleServiceSimpleFutureResult) writeField2(oprot thrift.TProtocol) (err error) { if p.IsSetNewErr_() { if err := oprot.WriteFieldBegin("newErr", thrift.STRUCT, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:newErr: ", p), err) } if err := p.NewErr_.Write(oprot); err != nil { return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.NewErr_), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:newErr: ", p), err) } } return err } func (p *SimpleServiceSimpleFutureResult) String() string { if p == nil { return "" } return fmt.Sprintf("SimpleServiceSimpleFutureResult(%+v)", *p) } ================================================ FILE: thrift/gen-go/test/tchan-test.go 
================================================ // @generated Code generated by thrift-gen. Do not modify. // Package test is generated code used to make or handle TChannel calls using Thrift. package test import ( "fmt" athrift "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" "github.com/uber/tchannel-go/thrift" ) // Interfaces for the service and client for the services defined in the IDL. // TChanMeta is the interface that defines the server handler and client interface. type TChanMeta interface { Health(ctx thrift.Context) (*HealthStatus, error) } // TChanSecondService is the interface that defines the server handler and client interface. type TChanSecondService interface { Echo(ctx thrift.Context, arg string) (string, error) } // TChanSimpleService is the interface that defines the server handler and client interface. type TChanSimpleService interface { Call(ctx thrift.Context, arg *Data) (*Data, error) Simple(ctx thrift.Context) error SimpleFuture(ctx thrift.Context) error } // Implementation of a client and service handler. type tchanMetaClient struct { thriftService string client thrift.TChanClient } func NewTChanMetaInheritedClient(thriftService string, client thrift.TChanClient) *tchanMetaClient { return &tchanMetaClient{ thriftService, client, } } // NewTChanMetaClient creates a client that can be used to make remote calls. 
func NewTChanMetaClient(client thrift.TChanClient) TChanMeta { return NewTChanMetaInheritedClient("Meta", client) } func (c *tchanMetaClient) Health(ctx thrift.Context) (*HealthStatus, error) { var resp MetaHealthResult args := MetaHealthArgs{} success, err := c.client.Call(ctx, c.thriftService, "health", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for health") } } return resp.GetSuccess(), err } type tchanMetaServer struct { handler TChanMeta } // NewTChanMetaServer wraps a handler for TChanMeta so it can be // registered with a thrift.Server. func NewTChanMetaServer(handler TChanMeta) thrift.TChanServer { return &tchanMetaServer{ handler, } } func (s *tchanMetaServer) Service() string { return "Meta" } func (s *tchanMetaServer) Methods() []string { return []string{ "health", } } func (s *tchanMetaServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "health": return s.handleHealth(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanMetaServer) handleHealth(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req MetaHealthArgs var res MetaHealthResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Health(ctx) if err != nil { return false, nil, err } else { res.Success = r } return err == nil, &res, nil } type tchanSecondServiceClient struct { thriftService string client thrift.TChanClient } func NewTChanSecondServiceInheritedClient(thriftService string, client thrift.TChanClient) *tchanSecondServiceClient { return &tchanSecondServiceClient{ thriftService, client, } } // NewTChanSecondServiceClient creates a client that can be used to make remote calls. 
func NewTChanSecondServiceClient(client thrift.TChanClient) TChanSecondService { return NewTChanSecondServiceInheritedClient("SecondService", client) } func (c *tchanSecondServiceClient) Echo(ctx thrift.Context, arg string) (string, error) { var resp SecondServiceEchoResult args := SecondServiceEchoArgs{ Arg: arg, } success, err := c.client.Call(ctx, c.thriftService, "Echo", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for Echo") } } return resp.GetSuccess(), err } type tchanSecondServiceServer struct { handler TChanSecondService } // NewTChanSecondServiceServer wraps a handler for TChanSecondService so it can be // registered with a thrift.Server. func NewTChanSecondServiceServer(handler TChanSecondService) thrift.TChanServer { return &tchanSecondServiceServer{ handler, } } func (s *tchanSecondServiceServer) Service() string { return "SecondService" } func (s *tchanSecondServiceServer) Methods() []string { return []string{ "Echo", } } func (s *tchanSecondServiceServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "Echo": return s.handleEcho(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanSecondServiceServer) handleEcho(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req SecondServiceEchoArgs var res SecondServiceEchoResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Echo(ctx, req.Arg) if err != nil { return false, nil, err } else { res.Success = &r } return err == nil, &res, nil } type tchanSimpleServiceClient struct { thriftService string client thrift.TChanClient } func NewTChanSimpleServiceInheritedClient(thriftService string, client thrift.TChanClient) *tchanSimpleServiceClient { return &tchanSimpleServiceClient{ thriftService, client, } } 
// NewTChanSimpleServiceClient creates a client that can be used to make remote calls. func NewTChanSimpleServiceClient(client thrift.TChanClient) TChanSimpleService { return NewTChanSimpleServiceInheritedClient("SimpleService", client) } func (c *tchanSimpleServiceClient) Call(ctx thrift.Context, arg *Data) (*Data, error) { var resp SimpleServiceCallResult args := SimpleServiceCallArgs{ Arg: arg, } success, err := c.client.Call(ctx, c.thriftService, "Call", &args, &resp) if err == nil && !success { switch { default: err = fmt.Errorf("received no result or unknown exception for Call") } } return resp.GetSuccess(), err } func (c *tchanSimpleServiceClient) Simple(ctx thrift.Context) error { var resp SimpleServiceSimpleResult args := SimpleServiceSimpleArgs{} success, err := c.client.Call(ctx, c.thriftService, "Simple", &args, &resp) if err == nil && !success { switch { case resp.SimpleErr != nil: err = resp.SimpleErr default: err = fmt.Errorf("received no result or unknown exception for Simple") } } return err } func (c *tchanSimpleServiceClient) SimpleFuture(ctx thrift.Context) error { var resp SimpleServiceSimpleFutureResult args := SimpleServiceSimpleFutureArgs{} success, err := c.client.Call(ctx, c.thriftService, "SimpleFuture", &args, &resp) if err == nil && !success { switch { case resp.SimpleErr != nil: err = resp.SimpleErr case resp.NewErr_ != nil: err = resp.NewErr_ default: err = fmt.Errorf("received no result or unknown exception for SimpleFuture") } } return err } type tchanSimpleServiceServer struct { handler TChanSimpleService } // NewTChanSimpleServiceServer wraps a handler for TChanSimpleService so it can be // registered with a thrift.Server. 
func NewTChanSimpleServiceServer(handler TChanSimpleService) thrift.TChanServer { return &tchanSimpleServiceServer{ handler, } } func (s *tchanSimpleServiceServer) Service() string { return "SimpleService" } func (s *tchanSimpleServiceServer) Methods() []string { return []string{ "Call", "Simple", "SimpleFuture", } } func (s *tchanSimpleServiceServer) Handle(ctx thrift.Context, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "Call": return s.handleCall(ctx, protocol) case "Simple": return s.handleSimple(ctx, protocol) case "SimpleFuture": return s.handleSimpleFuture(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanSimpleServiceServer) handleCall(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req SimpleServiceCallArgs var res SimpleServiceCallResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Call(ctx, req.Arg) if err != nil { return false, nil, err } else { res.Success = r } return err == nil, &res, nil } func (s *tchanSimpleServiceServer) handleSimple(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req SimpleServiceSimpleArgs var res SimpleServiceSimpleResult if err := req.Read(protocol); err != nil { return false, nil, err } err := s.handler.Simple(ctx) if err != nil { switch v := err.(type) { case *SimpleErr: if v == nil { return false, nil, fmt.Errorf("Handler for simpleErr returned non-nil error type *SimpleErr but nil value") } res.SimpleErr = v default: return false, nil, err } } else { } return err == nil, &res, nil } func (s *tchanSimpleServiceServer) handleSimpleFuture(ctx thrift.Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req SimpleServiceSimpleFutureArgs var res SimpleServiceSimpleFutureResult if err := req.Read(protocol); err != nil { return false, nil, err } err := 
s.handler.SimpleFuture(ctx) if err != nil { switch v := err.(type) { case *SimpleErr: if v == nil { return false, nil, fmt.Errorf("Handler for simpleErr returned non-nil error type *SimpleErr but nil value") } res.SimpleErr = v case *NewErr_: if v == nil { return false, nil, fmt.Errorf("Handler for newErr returned non-nil error type *NewErr_ but nil value") } res.NewErr_ = v default: return false, nil, err } } else { } return err == nil, &res, nil } ================================================ FILE: thrift/gen-go/test/ttypes.go ================================================ // Autogenerated by Thrift Compiler (1.0.0-dev) // DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING package test import ( "bytes" "fmt" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // (needed to ensure safety because of naive import list construction.) var _ = thrift.ZERO var _ = fmt.Printf var _ = bytes.Equal var GoUnusedProtection__ int // Attributes: // - B1 // - S2 // - I3 type Data struct { B1 bool `thrift:"b1,1,required" db:"b1" json:"b1"` S2 string `thrift:"s2,2,required" db:"s2" json:"s2"` I3 int32 `thrift:"i3,3,required" db:"i3" json:"i3"` } func NewData() *Data { return &Data{} } func (p *Data) GetB1() bool { return p.B1 } func (p *Data) GetS2() string { return p.S2 } func (p *Data) GetI3() int32 { return p.I3 } func (p *Data) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetB1 bool = false var issetS2 bool = false var issetI3 bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetB1 = true case 2: if err := p.ReadField2(iprot); err != nil { return err } issetS2 = true 
case 3: if err := p.ReadField3(iprot); err != nil { return err } issetI3 = true default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetB1 { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field B1 is not set")) } if !issetS2 { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field S2 is not set")) } if !issetI3 { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field I3 is not set")) } return nil } func (p *Data) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadBool(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.B1 = v } return nil } func (p *Data) ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.S2 = v } return nil } func (p *Data) ReadField3(iprot thrift.TProtocol) error { if v, err := iprot.ReadI32(); err != nil { return thrift.PrependError("error reading field 3: ", err) } else { p.I3 = v } return nil } func (p *Data) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("Data"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := p.writeField3(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *Data) writeField1(oprot thrift.TProtocol) (err error) { if err := 
oprot.WriteFieldBegin("b1", thrift.BOOL, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:b1: ", p), err) } if err := oprot.WriteBool(bool(p.B1)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.b1 (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:b1: ", p), err) } return err } func (p *Data) writeField2(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("s2", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:s2: ", p), err) } if err := oprot.WriteString(string(p.S2)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.s2 (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:s2: ", p), err) } return err } func (p *Data) writeField3(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("i3", thrift.I32, 3); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 3:i3: ", p), err) } if err := oprot.WriteI32(int32(p.I3)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.i3 (3) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 3:i3: ", p), err) } return err } func (p *Data) String() string { if p == nil { return "" } return fmt.Sprintf("Data(%+v)", *p) } // Attributes: // - Message type SimpleErr struct { Message string `thrift:"message,1" db:"message" json:"message"` } func NewSimpleErr() *SimpleErr { return &SimpleErr{} } func (p *SimpleErr) GetMessage() string { return p.Message } func (p *SimpleErr) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if 
err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *SimpleErr) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Message = v } return nil } func (p *SimpleErr) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("SimpleErr"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *SimpleErr) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("message", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:message: ", p), err) } if err := oprot.WriteString(string(p.Message)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.message (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:message: ", p), err) } return err } func (p *SimpleErr) String() string { if p == nil { return "" } return fmt.Sprintf("SimpleErr(%+v)", *p) } func (p *SimpleErr) Error() string { return p.String() } // Attributes: // - Message type NewErr_ struct { Message string `thrift:"message,1" db:"message" 
json:"message"` } func NewNewErr_() *NewErr_ { return &NewErr_{} } func (p *NewErr_) GetMessage() string { return p.Message } func (p *NewErr_) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } func (p *NewErr_) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Message = v } return nil } func (p *NewErr_) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("NewErr"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *NewErr_) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("message", thrift.STRING, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:message: ", p), err) } if err := oprot.WriteString(string(p.Message)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.message (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return 
thrift.PrependError(fmt.Sprintf("%T write field end error 1:message: ", p), err) } return err } func (p *NewErr_) String() string { if p == nil { return "" } return fmt.Sprintf("NewErr_(%+v)", *p) } func (p *NewErr_) Error() string { return p.String() } // Attributes: // - Ok // - Message type HealthStatus struct { Ok bool `thrift:"ok,1,required" db:"ok" json:"ok"` Message *string `thrift:"message,2" db:"message" json:"message,omitempty"` } func NewHealthStatus() *HealthStatus { return &HealthStatus{} } func (p *HealthStatus) GetOk() bool { return p.Ok } var HealthStatus_Message_DEFAULT string func (p *HealthStatus) GetMessage() string { if !p.IsSetMessage() { return HealthStatus_Message_DEFAULT } return *p.Message } func (p *HealthStatus) IsSetMessage() bool { return p.Message != nil } func (p *HealthStatus) Read(iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) } var issetOk bool = false for { _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() if err != nil { return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } if fieldTypeId == thrift.STOP { break } switch fieldId { case 1: if err := p.ReadField1(iprot); err != nil { return err } issetOk = true case 2: if err := p.ReadField2(iprot); err != nil { return err } default: if err := iprot.Skip(fieldTypeId); err != nil { return err } } if err := iprot.ReadFieldEnd(); err != nil { return err } } if err := iprot.ReadStructEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } if !issetOk { return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("Required field Ok is not set")) } return nil } func (p *HealthStatus) ReadField1(iprot thrift.TProtocol) error { if v, err := iprot.ReadBool(); err != nil { return thrift.PrependError("error reading field 1: ", err) } else { p.Ok = v } return nil } func (p *HealthStatus) 
ReadField2(iprot thrift.TProtocol) error { if v, err := iprot.ReadString(); err != nil { return thrift.PrependError("error reading field 2: ", err) } else { p.Message = &v } return nil } func (p *HealthStatus) Write(oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin("HealthStatus"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } if err := p.writeField1(oprot); err != nil { return err } if err := p.writeField2(oprot); err != nil { return err } if err := oprot.WriteFieldStop(); err != nil { return thrift.PrependError("write field stop error: ", err) } if err := oprot.WriteStructEnd(); err != nil { return thrift.PrependError("write struct stop error: ", err) } return nil } func (p *HealthStatus) writeField1(oprot thrift.TProtocol) (err error) { if err := oprot.WriteFieldBegin("ok", thrift.BOOL, 1); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:ok: ", p), err) } if err := oprot.WriteBool(bool(p.Ok)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.ok (1) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 1:ok: ", p), err) } return err } func (p *HealthStatus) writeField2(oprot thrift.TProtocol) (err error) { if p.IsSetMessage() { if err := oprot.WriteFieldBegin("message", thrift.STRING, 2); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field begin error 2:message: ", p), err) } if err := oprot.WriteString(string(*p.Message)); err != nil { return thrift.PrependError(fmt.Sprintf("%T.message (2) field write error: ", p), err) } if err := oprot.WriteFieldEnd(); err != nil { return thrift.PrependError(fmt.Sprintf("%T write field end error 2:message: ", p), err) } } return err } func (p *HealthStatus) String() string { if p == nil { return "" } return fmt.Sprintf("HealthStatus(%+v)", *p) } ================================================ FILE: thrift/headers.go 
================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift import ( "fmt" "io" "github.com/uber/tchannel-go/typed" ) // WriteHeaders writes the given key-value pairs using the following encoding: // len~2 (k~4 v~4)~len func WriteHeaders(w io.Writer, headers map[string]string) error { // TODO(prashant): Since we are not writing length-prefixed data here, // we can write out to the buffer, and if it fills up, flush it. // Right now, we calculate the size of the required buffer and write it out. // Calculate the size of the buffer that we need. 
size := 2 for k, v := range headers { size += 4 /* size of key/value lengths */ size += len(k) + len(v) } buf := make([]byte, size) writeBuffer := typed.NewWriteBuffer(buf) writeBuffer.WriteUint16(uint16(len(headers))) for k, v := range headers { writeBuffer.WriteLen16String(k) writeBuffer.WriteLen16String(v) } if err := writeBuffer.Err(); err != nil { return err } // Safety check to ensure the bytes written calculation is correct. if writeBuffer.BytesWritten() != size { return fmt.Errorf( "writeHeaders size calculation wrong, expected to write %v bytes, only wrote %v bytes", size, writeBuffer.BytesWritten()) } _, err := writeBuffer.FlushTo(w) return err } func readHeaders(reader *typed.Reader) (map[string]string, error) { numHeaders := reader.ReadUint16() if numHeaders == 0 { return nil, reader.Err() } headers := make(map[string]string, numHeaders) for i := 0; i < int(numHeaders) && reader.Err() == nil; i++ { k := reader.ReadLen16String() v := reader.ReadLen16String() headers[k] = v } return headers, reader.Err() } // ReadHeaders reads key-value pairs encoded using WriteHeaders. func ReadHeaders(r io.Reader) (map[string]string, error) { reader := typed.NewReader(r) m, err := readHeaders(reader) reader.Release() return m, err } ================================================ FILE: thrift/headers_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package thrift

import (
	"bytes"
	"io/ioutil"
	"testing"
	"testing/iotest"

	"github.com/stretchr/testify/assert"
)

// headers is the fixture map used by the read/write benchmarks.
var headers = map[string]string{
	"header1": "value1",
	"header2": "value2",
	"header3": "value1",
	"header4": "value2",
	"header5": "value1",
	"header6": "value2",
	"header7": "value1",
	"header8": "value2",
	"header9": "value1",
	"header0": "value2",
}

// headerTests pairs a header map with its wire encoding. encoding2 is an
// alternative encoding for maps with more than one entry, since the order in
// which entries are written is undefined.
var headerTests = []struct {
	m         map[string]string
	encoding  []byte
	encoding2 []byte
}{
	{
		m:        nil,
		encoding: []byte{0, 0},
	},
	{
		m:        make(map[string]string),
		encoding: []byte{0, 0},
	},
	{
		m: map[string]string{"k": "v"},
		encoding: []byte{
			0, 1, // number of headers
			0, 1, 'k', // key length, key
			0, 1, 'v', // value length, value
		},
	},
	{
		m: map[string]string{"": ""},
		encoding: []byte{
			0, 1, // number of headers
			0, 0, // empty key
			0, 0, // empty value
		},
	},
	{
		m: map[string]string{
			"k1": "v12",
			"k2": "v34",
		},
		encoding: []byte{
			0, 2, // number of headers
			0, 2, 'k', '2', // key length, key
			0, 3, 'v', '3', '4', // value length, value
			0, 2, 'k', '1', // key length, key
			0, 3, 'v', '1', '2', // value length, value
		},
		encoding2: []byte{
			0, 2, // number of headers
			0, 2, 'k', '1', // key length, key
			0, 3, 'v', '1', '2', // value length, value
			0, 2, 'k', '2', // key length, key
			0, 3, 'v', '3', '4', // value length, value
		},
	},
}

// TestWriteHeadersSuccessful checks that every fixture map is written as one
// of its accepted encodings.
func TestWriteHeadersSuccessful(t *testing.T) {
	for _, tc := range headerTests {
		var out bytes.Buffer
		err := WriteHeaders(&out, tc.m)
		assert.NoError(t, err, "WriteHeaders failed")

		// Map iteration order is undefined, so either encoding is acceptable:
		// anything that is not the primary encoding must equal encoding2.
		written := out.Bytes()
		if bytes.Equal(tc.encoding, written) {
			continue
		}
		assert.Equal(t, tc.encoding2, written, "Unexpected bytes")
	}
}

// TestReadHeadersSuccessful decodes every accepted encoding of each fixture,
// one byte at a time, and compares the result with the fixture map.
func TestReadHeadersSuccessful(t *testing.T) {
	for _, tc := range headerTests {
		// A {0, 0} payload always decodes to nil, so an empty non-nil map can
		// never round-trip.
		if tc.m != nil && len(tc.m) == 0 {
			continue
		}

		encodings := [][]byte{tc.encoding}
		if tc.encoding2 != nil {
			encodings = append(encodings, tc.encoding2)
		}

		for _, enc := range encodings {
			got, err := ReadHeaders(iotest.OneByteReader(bytes.NewReader(enc)))
			assert.NoError(t, err, "ReadHeaders failed")
			assert.Equal(t, tc.m, got, "Map mismatch")
		}
	}
}

// TestReadHeadersLeftoverBytes checks that ReadHeaders leaves the bytes that
// follow the encoded headers unread in the underlying reader.
func TestReadHeadersLeftoverBytes(t *testing.T) {
	src := bytes.NewReader([]byte{0, 0, 1, 2, 3})

	got, err := ReadHeaders(src)
	assert.NoError(t, err, "ReadHeaders failed")
	assert.Equal(t, map[string]string(nil), got, "Headers mismatch")

	rest, err := ioutil.ReadAll(src)
	assert.NoError(t, err, "ReadAll failed")
	assert.Equal(t, []byte{1, 2, 3}, rest, "Reader consumed leftover bytes")
}

// BenchmarkWriteHeaders measures encoding the fixture map to a discarding writer.
func BenchmarkWriteHeaders(b *testing.B) {
	for n := 0; n < b.N; n++ {
		WriteHeaders(ioutil.Discard, headers)
	}
}

// BenchmarkReadHeaders measures decoding one pre-encoded copy of the fixture
// map, rewinding the reader before every iteration.
func BenchmarkReadHeaders(b *testing.B) {
	var encoded bytes.Buffer
	assert.NoError(b, WriteHeaders(&encoded, headers))
	src := bytes.NewReader(encoded.Bytes())

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		src.Seek(0, 0)
		ReadHeaders(src)
	}
}

================================================
FILE: thrift/interfaces.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package thrift

import athrift "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"

// This file defines interfaces that are used or exposed by thrift-gen generated code.
// TChanClient is used by the generated code to make outgoing requests.
// TChanServer is exposed by the generated code, and is called on incoming requests.

// TChanClient abstracts calling a Thrift endpoint, and is used by the generated client code.
type TChanClient interface {
	// Call should be passed the service and method to call along with the
	// request/response Thrift structs. It returns a success flag and an error.
	Call(ctx Context, serviceName, methodName string, req, resp athrift.TStruct) (success bool, err error)
}

// TChanServer abstracts handling of an RPC that is implemented by the generated server code.
type TChanServer interface {
	// Handle should read the request for methodName from the given protocol,
	// and return the response struct.
	// The values returned are: success, result struct, unexpected error.
	Handle(ctx Context, methodName string, protocol athrift.TProtocol) (success bool, resp athrift.TStruct, err error)

	// Service returns the service name.
	Service() string

	// Methods returns the method names handled by this server.
	Methods() []string
}

================================================
FILE: thrift/meta.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift import ( "errors" "runtime" "strings" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/thrift/gen-go/meta" ) // HealthFunc is the interface for custom health endpoints. // ok is whether the service health is OK, and message is optional additional information for the health result. type HealthFunc func(ctx Context) (ok bool, message string) // HealthRequestType is the type of health check. type HealthRequestType int const ( // Process health checks are used to check whether the process is up // and should almost always return true immediately. Process HealthRequestType = iota // Traffic health checks are used to check whether the process should // receive traffic. This can be used to keep a process running, but // not receiving health checks (e.g., during process warm-up). Traffic ) // HealthRequest is optional parametres for a health request. 
type HealthRequest struct { // Type is the type of health check being requested. Type HealthRequestType } // HealthRequestFunc is a health check function that includes parameters // about the health check. type HealthRequestFunc func(Context, HealthRequest) (ok bool, message string) // healthHandler implements the default health check enpoint. type metaHandler struct { healthFn HealthRequestFunc } // newMetaHandler return a new HealthHandler instance. func newMetaHandler() *metaHandler { return &metaHandler{healthFn: defaultHealth} } // Health returns true as default Health endpoint. func (h *metaHandler) Health(ctx Context, req *meta.HealthRequest) (*meta.HealthStatus, error) { ok, message := h.healthFn(ctx, metaReqToReq(req)) if message == "" { return &meta.HealthStatus{Ok: ok}, nil } return &meta.HealthStatus{Ok: ok, Message: &message}, nil } func (h *metaHandler) ThriftIDL(ctx Context) (*meta.ThriftIDLs, error) { // TODO(prashant): Add thriftIDL to the generated code. return nil, errors.New("unimplemented") } func (h *metaHandler) VersionInfo(ctx Context) (*meta.VersionInfo, error) { return &meta.VersionInfo{ Language: "go", LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), Version: tchannel.VersionInfo, }, nil } func defaultHealth(ctx Context, r HealthRequest) (bool, string) { return true, "" } func (h *metaHandler) setHandler(f HealthRequestFunc) { h.healthFn = f } func metaReqToReq(r *meta.HealthRequest) HealthRequest { if r == nil { return HealthRequest{} } return HealthRequest{ Type: HealthRequestType(r.GetType()), } } ================================================ FILE: thrift/meta.thrift ================================================ // The HealthState provides additional information when the // health endpoint returns !ok. enum HealthState { REFUSING = 0, ACCEPTING = 1, STOPPING = 2, STOPPED = 3, } // The HealthRequestType is the type of health check, as a process may want to // return that it's running, but not ready for traffic. 
enum HealthRequestType {
    // PROCESS indicates that the health check is for checking that
    // the process is up. Handlers should always return "ok".
    PROCESS = 0,

    // TRAFFIC indicates that the health check is for checking whether
    // the process wants to receive traffic. The process may want to reject
    // traffic due to warmup, or before shutdown to avoid in-flight requests
    // when the process exits.
    TRAFFIC = 1,
}

// HealthRequest carries the optional arguments of a health check.
struct HealthRequest {
    1: optional HealthRequestType type
}

// HealthStatus is the result of a health check.
struct HealthStatus {
    1: required bool ok
    2: optional string message
    3: optional HealthState state
}

typedef string filename

struct ThriftIDLs {
    // map: filename -> contents
    1: required map<filename, string> idls
    // the entry IDL that imports others
    2: required filename entryPoint
}

struct VersionInfo {
    // short string naming the implementation language
    1: required string language
    // language-specific version string representing runtime or build chain
    2: required string language_version
    // semver version indicating the version of the tchannel library
    3: required string version
}

service Meta {
    // All arguments are optional. The default is a PROCESS health request.
    HealthStatus health(1: HealthRequest hr)

    ThriftIDLs thriftIDL()
    VersionInfo versionInfo()
}

================================================
FILE: thrift/meta_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift import ( "runtime" "strings" "testing" "time" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" "github.com/uber/tchannel-go/thrift/gen-go/meta" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestThriftIDL(t *testing.T) { withMetaSetup(t, func(ctx Context, c tchanMeta, server *Server) { _, err := c.ThriftIDL(ctx) assert.Error(t, err, "Health endpoint failed") assert.Contains(t, err.Error(), "unimplemented") }) } func TestVersionInfo(t *testing.T) { withMetaSetup(t, func(ctx Context, c tchanMeta, server *Server) { ret, err := c.VersionInfo(ctx) if assert.NoError(t, err, "VersionInfo endpoint failed") { expected := &meta.VersionInfo{ Language: "go", LanguageVersion: strings.TrimPrefix(runtime.Version(), "go"), Version: tchannel.VersionInfo, } assert.Equal(t, expected, ret, "Unexpected version info") } }) } func TestHealth(t *testing.T) { tests := []struct { msg string healthFunc HealthFunc healthReqFunc HealthRequestFunc req *meta.HealthRequest wantOK bool wantMessage *string }{ { msg: "default health func", wantOK: true, }, { msg: "healthFunc returning unhealthy, no message", healthFunc: func(Context) (bool, string) { return false, "" }, wantOK: false, }, { msg: "healthFunc returning healthy, with message", healthFunc: func(Context) (bool, string) { return true, "ok" }, wantOK: true, wantMessage: stringPtr("ok"), }, { msg: "healthReqFunc returning unhealthy for traffic, default check", healthReqFunc: func(_ Context, r 
HealthRequest) (bool, string) { return r.Type != Traffic, "" }, wantOK: true, }, { msg: "healthReqFunc returning unhealthy for traffic, traffic check", healthReqFunc: func(_ Context, r HealthRequest) (bool, string) { return r.Type != Traffic, "" }, req: &meta.HealthRequest{Type: meta.HealthRequestTypePtr(meta.HealthRequestType_TRAFFIC)}, wantOK: false, }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { withMetaSetup(t, func(ctx Context, c tchanMeta, server *Server) { if tt.healthFunc != nil { server.RegisterHealthHandler(tt.healthFunc) } if tt.healthReqFunc != nil { server.RegisterHealthRequestHandler(tt.healthReqFunc) } req := tt.req if req == nil { req = &meta.HealthRequest{} } ret, err := c.Health(ctx, req) require.NoError(t, err, "Health endpoint failed") assert.Equal(t, tt.wantOK, ret.Ok, "Health status mismatch") assert.Equal(t, tt.wantMessage, ret.Message, "Health message mismatch") }) }) } } func TestMetaReqToReq(t *testing.T) { tests := []struct { msg string r *meta.HealthRequest want HealthRequest }{ { msg: "nil", r: nil, want: HealthRequest{}, }, { msg: "default", r: &meta.HealthRequest{}, want: HealthRequest{}, }, { msg: "explcit process check", r: &meta.HealthRequest{ Type: meta.HealthRequestTypePtr(meta.HealthRequestType_PROCESS), }, want: HealthRequest{ Type: Process, }, }, { msg: "explcit traffic check", r: &meta.HealthRequest{ Type: meta.HealthRequestTypePtr(meta.HealthRequestType_TRAFFIC), }, want: HealthRequest{ Type: Traffic, }, }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { assert.Equal(t, tt.want, metaReqToReq(tt.r)) }) } } func withMetaSetup(t *testing.T, f func(ctx Context, c tchanMeta, server *Server)) { ctx, cancel := NewContext(time.Second * 10) defer cancel() // Start server tchan, server := setupMetaServer(t) defer tchan.Close() // Get client1 c := getMetaClient(t, tchan.PeerInfo().HostPort) f(ctx, c, server) } func setupMetaServer(t *testing.T) (*tchannel.Channel, *Server) { tchan := 
testutils.NewServer(t, testutils.NewOpts().SetServiceName("meta")) server := NewServer(tchan) return tchan, server } func getMetaClient(t *testing.T, dst string) tchanMeta { tchan := testutils.NewClient(t, nil) tchan.Peers().Add(dst) thriftClient := NewClient(tchan, "meta", nil) return newTChanMetaClient(thriftClient) } func stringPtr(s string) *string { return &s } ================================================ FILE: thrift/mocks/TChanMeta.go ================================================ package mocks import "github.com/uber/tchannel-go/thrift/gen-go/meta" import "github.com/stretchr/testify/mock" import "github.com/uber/tchannel-go/thrift" type TChanMeta struct { mock.Mock } func (_m *TChanMeta) Health(ctx thrift.Context) (*meta.HealthStatus, error) { ret := _m.Called(ctx) var r0 *meta.HealthStatus if rf, ok := ret.Get(0).(func(thrift.Context) *meta.HealthStatus); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*meta.HealthStatus) } } var r1 error if rf, ok := ret.Get(1).(func(thrift.Context) error); ok { r1 = rf(ctx) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: thrift/mocks/TChanSecondService.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package mocks import "github.com/stretchr/testify/mock" import "github.com/uber/tchannel-go/thrift" type TChanSecondService struct { mock.Mock } func (_m *TChanSecondService) Echo(_ctx thrift.Context, _arg string) (string, error) { ret := _m.Called(_ctx, _arg) var r0 string if rf, ok := ret.Get(0).(func(thrift.Context, string) string); ok { r0 = rf(_ctx, _arg) } else { r0 = ret.Get(0).(string) } var r1 error if rf, ok := ret.Get(1).(func(thrift.Context, string) error); ok { r1 = rf(_ctx, _arg) } else { r1 = ret.Error(1) } return r0, r1 } ================================================ FILE: thrift/mocks/TChanSimpleService.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package mocks import "github.com/uber/tchannel-go/thrift/gen-go/test" import "github.com/stretchr/testify/mock" import "github.com/uber/tchannel-go/thrift" type TChanSimpleService struct { mock.Mock } func (_m *TChanSimpleService) Call(_ctx thrift.Context, _arg *test.Data) (*test.Data, error) { ret := _m.Called(_ctx, _arg) var r0 *test.Data if rf, ok := ret.Get(0).(func(thrift.Context, *test.Data) *test.Data); ok { r0 = rf(_ctx, _arg) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*test.Data) } } var r1 error if rf, ok := ret.Get(1).(func(thrift.Context, *test.Data) error); ok { r1 = rf(_ctx, _arg) } else { r1 = ret.Error(1) } return r0, r1 } func (_m *TChanSimpleService) Simple(_ctx thrift.Context) error { ret := _m.Called(_ctx) var r0 error if rf, ok := ret.Get(0).(func(thrift.Context) error); ok { r0 = rf(_ctx) } else { r0 = ret.Error(0) } return r0 } func (_m *TChanSimpleService) SimpleFuture(_ctx thrift.Context) error { ret := _m.Called(_ctx) var r0 error if rf, ok := ret.Get(0).(func(thrift.Context) error); ok { r0 = rf(_ctx) } else { r0 = ret.Error(0) } return r0 } ================================================ FILE: thrift/options.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package thrift

import (
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
	"golang.org/x/net/context"
)

// RegisterOption is the interface for options to Register.
type RegisterOption interface {
	// Apply configures the handler that is being registered.
	Apply(h *handler)
}

// PostResponseCB registers a callback that is run after a response has been
// completely processed (e.g. written to the channel).
// This gives the server a chance to clean up resources from the response object.
type PostResponseCB func(ctx context.Context, method string, response thrift.TStruct)

// optPostResponse adapts a PostResponseCB to the RegisterOption interface.
type optPostResponse PostResponseCB

// OptPostResponse registers a PostResponseCB.
func OptPostResponse(cb PostResponseCB) RegisterOption {
	return optPostResponse(cb)
}

// Apply stores the wrapped callback as the handler's postResponseCB.
func (o optPostResponse) Apply(h *handler) {
	h.postResponseCB = PostResponseCB(o)
}

================================================
FILE: thrift/server.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package thrift

import (
	"log"
	"strings"
	"sync"

	tchannel "github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/internal/argreader"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
	"golang.org/x/net/context"
)

// handler pairs a registered TChanServer with the optional callback set
// through its RegisterOptions.
type handler struct {
	server         TChanServer
	postResponseCB PostResponseCB
}

// Server handles incoming TChannel calls and forwards them to the matching TChanServer.
type Server struct { sync.RWMutex ch tchannel.Registrar log tchannel.Logger handlers map[string]handler metaHandler *metaHandler ctxFn func(ctx context.Context, method string, headers map[string]string) Context } // NewServer returns a server that can serve thrift services over TChannel. func NewServer(registrar tchannel.Registrar) *Server { metaHandler := newMetaHandler() server := &Server{ ch: registrar, log: registrar.Logger(), handlers: make(map[string]handler), metaHandler: metaHandler, ctxFn: defaultContextFn, } server.Register(newTChanMetaServer(metaHandler)) if ch, ok := registrar.(*tchannel.Channel); ok { // Register the meta endpoints on the "tchannel" service name. NewServer(ch.GetSubChannel("tchannel")) } return server } // Register registers the given TChanServer to be called on any incoming call for its' services. // TODO(prashant): Replace Register call with this call. func (s *Server) Register(svr TChanServer, opts ...RegisterOption) { service := svr.Service() handler := &handler{server: svr} for _, opt := range opts { opt.Apply(handler) } s.Lock() s.handlers[service] = *handler s.Unlock() for _, m := range svr.Methods() { s.ch.Register(s, service+"::"+m) } } // RegisterHealthHandler uses the user-specified function f for the Health endpoint. func (s *Server) RegisterHealthHandler(f HealthFunc) { wrapped := func(ctx Context, r HealthRequest) (bool, string) { return f(ctx) } s.metaHandler.setHandler(wrapped) } // RegisterHealthRequestHandler uses the user-specified function for the // Health endpoint. The function receives the health request which includes // information about the type of the request being performed. func (s *Server) RegisterHealthRequestHandler(f HealthRequestFunc) { s.metaHandler.setHandler(f) } // SetContextFn sets the function used to convert a context.Context to a thrift.Context. // Note: This API may change and is only intended to bridge different contexts. 
func (s *Server) SetContextFn(f func(ctx context.Context, method string, headers map[string]string) Context) { s.ctxFn = f } func (s *Server) onError(call *tchannel.InboundCall, err error) { // TODO(prashant): Expose incoming call errors through options for NewServer. remotePeer := call.RemotePeer() logger := s.log.WithFields( tchannel.ErrField(err), tchannel.LogField{Key: "method", Value: call.MethodString()}, tchannel.LogField{Key: "callerName", Value: call.CallerName()}, // TODO: These are very similar to the connection fields, but we don't // have access to the connection's logger. Consider exposing the // connection through CurrentCall. tchannel.LogField{Key: "localAddr", Value: call.LocalPeer().HostPort}, tchannel.LogField{Key: "remoteHostPort", Value: remotePeer.HostPort}, tchannel.LogField{Key: "remoteIsEphemeral", Value: remotePeer.IsEphemeral}, tchannel.LogField{Key: "remoteProcess", Value: remotePeer.ProcessName}, ) if tchannel.GetSystemErrorCode(err) == tchannel.ErrCodeTimeout { logger.Debug("Thrift server timeout.") } else { logger.Error("Thrift server error.") } } func defaultContextFn(ctx context.Context, method string, headers map[string]string) Context { return WithHeaders(ctx, headers) } func (s *Server) handle(origCtx context.Context, handler handler, method string, call *tchannel.InboundCall) error { reader, err := call.Arg2Reader() if err != nil { return err } headers, err := ReadHeaders(reader) if err != nil { return err } if err := argreader.EnsureEmpty(reader, "reading request headers"); err != nil { return err } if err := reader.Close(); err != nil { return err } reader, err = call.Arg3Reader() if err != nil { return err } tracer := tchannel.TracerFromRegistrar(s.ch) origCtx = tchannel.ExtractInboundSpan(origCtx, call, headers, tracer) ctx := s.ctxFn(origCtx, method, headers) wp := getProtocolReader(reader) success, resp, err := handler.server.Handle(ctx, method, wp.protocol) thriftProtocolPool.Put(wp) if handler.postResponseCB != nil { 
defer handler.postResponseCB(ctx, method, resp) } if err != nil { if _, ok := err.(thrift.TProtocolException); ok { // We failed to parse the Thrift generated code, so convert the error to bad request. err = tchannel.NewSystemError(tchannel.ErrCodeBadRequest, err.Error()) } reader.Close() call.Response().SendSystemError(err) return nil } if err := argreader.EnsureEmpty(reader, "reading request body"); err != nil { return err } if err := reader.Close(); err != nil { return err } if !success { call.Response().SetApplicationError() } writer, err := call.Response().Arg2Writer() if err != nil { return err } if err := WriteHeaders(writer, ctx.ResponseHeaders()); err != nil { return err } if err := writer.Close(); err != nil { return err } writer, err = call.Response().Arg3Writer() wp = getProtocolWriter(writer) defer thriftProtocolPool.Put(wp) if err := resp.Write(wp.protocol); err != nil { call.Response().SendSystemError(err) return err } return writer.Close() } func getServiceMethod(method string) (string, string, bool) { s := string(method) sep := strings.Index(s, "::") if sep == -1 { return "", "", false } return s[:sep], s[sep+2:], true } // Handle handles an incoming TChannel call and forwards it to the correct handler. 
func (s *Server) Handle(ctx context.Context, call *tchannel.InboundCall) { op := call.MethodString() service, method, ok := getServiceMethod(op) if !ok { log.Fatalf("Handle got call for %s which does not match the expected call format", op) } s.RLock() handler, ok := s.handlers[service] s.RUnlock() if !ok { log.Fatalf("Handle got call for service %v which is not registered", service) } if err := s.handle(ctx, handler, method, call); err != nil { s.onError(call, err) } } ================================================ FILE: thrift/server_test.go ================================================ package thrift import ( "errors" "testing" "time" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" athrift "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) var errIO = errors.New("IO Error") // badTStruct implements TStruct that always fails with the provided error. type badTStruct struct { // If specified, runs the specified function before failing the Write. PreWrite func(athrift.TProtocol) Err error } func (t *badTStruct) Write(p athrift.TProtocol) error { if t.PreWrite != nil { t.PreWrite(p) } return t.Err } func (t *badTStruct) Read(p athrift.TProtocol) error { return t.Err } // nullTStruct implements TStruct that does nothing at all with no errors. type nullTStruct struct{} func (*nullTStruct) Write(p athrift.TProtocol) error { return nil } func (*nullTStruct) Read(p athrift.TProtocol) error { return nil } // thriftStruction is a TChannel service that implements the following // methods: // // destruct // Returns a TStruct that fails without writing anything. // partialDestruct // Returns a TStruct that fails after writing partial output. 
type thriftStruction struct{} func (ts *thriftStruction) Handle( ctx Context, methodName string, protocol athrift.TProtocol, ) (success bool, resp athrift.TStruct, err error) { var preWrite func(athrift.TProtocol) if methodName == "partialDestruct" { preWrite = func(p athrift.TProtocol) { p.WriteStructBegin("foo") p.WriteFieldBegin("bar", athrift.STRING, 42) p.WriteString("baz") } } // successful call with a TStruct that fails while writing. return true, &badTStruct{Err: errIO, PreWrite: preWrite}, nil } func (ts *thriftStruction) Service() string { return "destruct" } func (ts *thriftStruction) Methods() []string { return []string{"destruct", "partialDestruct"} } func TestHandleTStructError(t *testing.T) { serverOpts := testutils.NewOpts(). AddLogFilter( "Thrift server error.", 1, "error", "IO Error", "method", "destruct::destruct"). AddLogFilter( "Thrift server error.", 1, "error", "IO Error", "method", "destruct::partialDestruct") server := testutils.NewTestServer(t, serverOpts) defer server.CloseAndVerify() // Create a thrift server with a handler that returns success with // TStructs that refuse to do I/O. 
tchan := server.Server() NewServer(tchan).Register(&thriftStruction{}) client := NewClient( server.NewClient(testutils.NewOpts()), tchan.ServiceName(), &ClientOptions{HostPort: server.HostPort()}, ) t.Run("failing response", func(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() _, err := client.Call(ctx, "destruct", "destruct", &nullTStruct{}, &nullTStruct{}) assert.Error(t, err) assert.IsType(t, tchannel.SystemError{}, err) assert.Equal(t, tchannel.ErrCodeUnexpected, tchannel.GetSystemErrorCode(err)) assert.Equal(t, "IO Error", tchannel.GetSystemErrorMessage(err)) }) t.Run("failing response with partial write", func(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() _, err := client.Call(ctx, "destruct", "partialDestruct", &nullTStruct{}, &nullTStruct{}) assert.Error(t, err) assert.IsType(t, tchannel.SystemError{}, err) assert.Equal(t, tchannel.ErrCodeUnexpected, tchannel.GetSystemErrorCode(err)) assert.Equal(t, "IO Error", tchannel.GetSystemErrorMessage(err)) }) } ================================================ FILE: thrift/struct.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift import ( "io" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // WriteStruct writes the given Thrift struct to a writer. It pools TProtocols. func WriteStruct(writer io.Writer, s thrift.TStruct) error { wp := getProtocolWriter(writer) err := s.Write(wp.protocol) thriftProtocolPool.Put(wp) return err } // ReadStruct reads the given Thrift struct. It pools TProtocols. func ReadStruct(reader io.Reader, s thrift.TStruct) error { wp := getProtocolReader(reader) err := s.Read(wp.protocol) thriftProtocolPool.Put(wp) return err } ================================================ FILE: thrift/struct_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift_test import ( "bytes" "io/ioutil" "sync" "testing" . "github.com/uber/tchannel-go/thrift" "github.com/uber/tchannel-go/testutils/testreader" "github.com/uber/tchannel-go/testutils/testwriter" "github.com/uber/tchannel-go/thrift/gen-go/test" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) var structTest = struct { s thrift.TStruct encoded []byte }{ s: &test.Data{ B1: true, S2: "S2", I3: 3, }, encoded: []byte{ 0x2, // bool 0x0, 0x1, // field 1 0x1, // true 0xb, // string 0x0, 0x2, // field 2 0x0, 0x0, 0x0, 0x2, // length of string "S2" 'S', '2', // string "S2" 0x8, // i32 0x0, 0x3, // field 3 0x0, 0x0, 0x0, 0x3, // i32 3 0x0, // end of struct }, } func TestReadStruct(t *testing.T) { appendBytes := func(bs []byte, append []byte) []byte { b := make([]byte, len(bs)+len(append)) n := copy(b, bs) copy(b[n:], append) return b } tests := []struct { s thrift.TStruct encoded []byte wantErr bool leftover []byte }{ { s: structTest.s, encoded: structTest.encoded, }, { s: &test.Data{ B1: true, S2: "S2", }, // Missing field 3. encoded: structTest.encoded[:len(structTest.encoded)-8], wantErr: true, }, { s: structTest.s, encoded: appendBytes(structTest.encoded, []byte{1, 2, 3, 4}), leftover: []byte{1, 2, 3, 4}, }, } for _, tt := range tests { reader := bytes.NewReader(tt.encoded) var s thrift.TStruct = &test.Data{} err := ReadStruct(reader, s) assert.Equal(t, tt.wantErr, err != nil, "Unexpected error: %v", err) // Even if there's an error, the struct will be partially filled. 
assert.Equal(t, tt.s, s, "Unexpected struct") leftover, err := ioutil.ReadAll(reader) if assert.NoError(t, err, "Read leftover bytes failed") { // ReadAll always returns a non-nil byte slice. if tt.leftover == nil { tt.leftover = make([]byte, 0) } assert.Equal(t, tt.leftover, leftover, "Leftover bytes mismatch") } } } func TestReadStructErr(t *testing.T) { writer, reader := testreader.ChunkReader() writer <- structTest.encoded[:10] writer <- nil close(writer) s := &test.Data{} err := ReadStruct(reader, s) if assert.Error(t, err, "ReadStruct should fail") { // Apache Thrift just prepends the error message, and doesn't give us access // to the underlying error, so we can't check the underlying error exactly. assert.Contains(t, err.Error(), testreader.ErrUser.Error(), "Underlying error missing") } } func TestWriteStruct(t *testing.T) { tests := []struct { s thrift.TStruct encoded []byte wantErr bool }{ { s: structTest.s, encoded: structTest.encoded, }, } for _, tt := range tests { buf := &bytes.Buffer{} err := WriteStruct(buf, tt.s) assert.Equal(t, tt.wantErr, err != nil, "Unexpected err: %v", err) if err != nil { continue } assert.Equal(t, tt.encoded, buf.Bytes(), "Encoded data mismatch") } } func TestWriteStructErr(t *testing.T) { writer := testwriter.Limited(10) err := WriteStruct(writer, structTest.s) if assert.Error(t, err, "WriteStruct should fail") { // Apache Thrift just prepends the error message, and doesn't give us access // to the underlying error, so we can't check the underlying error exactly. 
assert.Contains(t, err.Error(), testwriter.ErrOutOfSpace.Error(), "Underlying error missing") } } func TestParallelReadWrites(t *testing.T) { var wg sync.WaitGroup testBG := func(f func(t *testing.T)) { wg.Add(1) go func() { f(t) wg.Done() }() } for i := 0; i < 50; i++ { testBG(TestReadStruct) testBG(TestWriteStruct) } wg.Wait() } func BenchmarkWriteStruct(b *testing.B) { buf := &bytes.Buffer{} for i := 0; i < b.N; i++ { buf.Reset() WriteStruct(buf, structTest.s) } } func BenchmarkReadStruct(b *testing.B) { buf := bytes.NewReader(structTest.encoded) var d test.Data buf.Seek(0, 0) assert.NoError(b, ReadStruct(buf, &d)) b.ResetTimer() for i := 0; i < b.N; i++ { buf.Seek(0, 0) ReadStruct(buf, &d) } } ================================================ FILE: thrift/tchan-meta.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package thrift import ( "fmt" athrift "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" gen "github.com/uber/tchannel-go/thrift/gen-go/meta" ) // Interfaces for the service and client for the services defined in the IDL. // tchanMeta is the interface that defines the server handler and client interface. type tchanMeta interface { Health(ctx Context, req *gen.HealthRequest) (*gen.HealthStatus, error) ThriftIDL(ctx Context) (*gen.ThriftIDLs, error) VersionInfo(ctx Context) (*gen.VersionInfo, error) } // Implementation of a client and service handler. type tchanMetaClient struct { thriftService string client TChanClient } func newTChanMetaClient(client TChanClient) tchanMeta { return &tchanMetaClient{ "Meta", client, } } func (c *tchanMetaClient) Health(ctx Context, req *gen.HealthRequest) (*gen.HealthStatus, error) { var resp gen.MetaHealthResult args := gen.MetaHealthArgs{ Hr: req, } success, err := c.client.Call(ctx, c.thriftService, "health", &args, &resp) if err == nil && !success { } return resp.GetSuccess(), err } func (c *tchanMetaClient) ThriftIDL(ctx Context) (*gen.ThriftIDLs, error) { var resp gen.MetaThriftIDLResult args := gen.MetaThriftIDLArgs{} success, err := c.client.Call(ctx, c.thriftService, "thriftIDL", &args, &resp) if err == nil && !success { } return resp.GetSuccess(), err } func (c *tchanMetaClient) VersionInfo(ctx Context) (*gen.VersionInfo, error) { var resp gen.MetaVersionInfoResult args := gen.MetaVersionInfoArgs{} success, err := c.client.Call(ctx, c.thriftService, "versionInfo", &args, &resp) if err == nil && !success { } return resp.GetSuccess(), err } type tchanMetaServer struct { handler tchanMeta } func newTChanMetaServer(handler tchanMeta) TChanServer { return &tchanMetaServer{ handler, } } func (s *tchanMetaServer) Service() string { return "Meta" } func (s *tchanMetaServer) Methods() []string { return []string{ "health", "thriftIDL", "versionInfo", } } func (s *tchanMetaServer) Handle(ctx Context, 
methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { case "health": return s.handleHealth(ctx, protocol) case "thriftIDL": return s.handleThriftIDL(ctx, protocol) case "versionInfo": return s.handleVersionInfo(ctx, protocol) default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } func (s *tchanMetaServer) handleHealth(ctx Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req gen.MetaHealthArgs var res gen.MetaHealthResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.Health(ctx, req.Hr) if err != nil { return false, nil, err } res.Success = r return err == nil, &res, nil } func (s *tchanMetaServer) handleThriftIDL(ctx Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req gen.MetaThriftIDLArgs var res gen.MetaThriftIDLResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.ThriftIDL(ctx) if err != nil { return false, nil, err } res.Success = r return err == nil, &res, nil } func (s *tchanMetaServer) handleVersionInfo(ctx Context, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req gen.MetaVersionInfoArgs var res gen.MetaVersionInfoResult if err := req.Read(protocol); err != nil { return false, nil, err } r, err := s.handler.VersionInfo(ctx) if err != nil { return false, nil, err } res.Success = r return err == nil, &res, nil } ================================================ FILE: thrift/test.thrift ================================================ struct Data { 1: required bool b1, 2: required string s2, 3: required i32 i3 } exception SimpleErr { 1: string message } exception NewErr { 1: string message } service SimpleService { Data Call(1: Data arg) void Simple() throws (1: SimpleErr simpleErr) void SimpleFuture() throws (1: SimpleErr simpleErr, 2: NewErr newErr) } service SecondService { string Echo(1: string arg) } struct 
// These tests ensure that the code generator generates valid code that can be built
// in combination with Thrift's autogenerated code.

// _tchannelPackage is the import path of tchannel-go. It is used both to lay
// out the package directory inside the temporary GOPATH and to build the
// package path of generated test output.
const _tchannelPackage = "github.com/uber/tchannel-go"

var (
	// _testGoPath is the temporary GOPATH created by createGoPath. It stays
	// empty until a test needs an output directory; TestMain removes it only
	// when the test run exits with code 0.
	_testGoPath string
	// _testGoPathOnce ensures createGoPath runs at most once per test binary.
	_testGoPathOnce sync.Once
)
if _testGoPath != "" && exitCode == 0 { os.RemoveAll(_testGoPath) } os.Exit(exitCode) } func getTChannelDir(goPath string) string { return filepath.Join(goPath, "src", _tchannelPackage) } func getCurrentTChannelPath(t *testing.T) string { wd, err := os.Getwd() require.NoError(t, err, "Failed to get working directory") // Walk up "wd" till we find "tchannel-go". for filepath.Base(wd) != filepath.Base(_tchannelPackage) { wd = filepath.Dir(wd) if wd == "" { t.Fatalf("Failed to find tchannel-go in parents of current directory") } } return wd } func createGoPath(t *testing.T) { goPath, err := ioutil.TempDir("", "thrift-gen") require.NoError(t, err, "TempDir failed") // Create $GOPATH/src/github.com/uber/tchannel-go and symlink everything. // And then create a dummy directory for all the test output. tchannelDir := getTChannelDir(goPath) require.NoError(t, os.MkdirAll(tchannelDir, 0755), "MkDirAll failed") // Symlink the contents of tchannel-go into the temp directory. realTChannelDir := getCurrentTChannelPath(t) realDirContents, err := ioutil.ReadDir(realTChannelDir) require.NoError(t, err, "Failed to read real tchannel-go dir") for _, f := range realDirContents { realPath := filepath.Join(realTChannelDir, f.Name()) err := os.Symlink(realPath, filepath.Join(tchannelDir, filepath.Base(f.Name()))) require.NoError(t, err, "Failed to symlink %v", f.Name()) } _testGoPath = goPath // None of the other tests in this package should use GOPATH, so we don't // restore this. os.Setenv("GOPATH", goPath) } func getOutputDir(t *testing.T) (dir, pkg string) { _testGoPathOnce.Do(func() { createGoPath(t) }) // Create a random directory inside of the GOPATH in tmp randStr := testutils.RandString(10) randDir := filepath.Join(getTChannelDir(_testGoPath), randStr) // In case it's not empty. 
os.RemoveAll(randDir) return randDir, filepath.Join(_tchannelPackage, randStr) } func TestAllThrift(t *testing.T) { files, err := ioutil.ReadDir("test_files") require.NoError(t, err, "Cannot read test_files directory: %v", err) for _, f := range files { fname := f.Name() if f.IsDir() || filepath.Ext(fname) != ".thrift" { continue } if err := runBuildTest(t, filepath.Join("test_files", fname)); err != nil { t.Errorf("Thrift file %v failed: %v", fname, err) } } } func TestIncludeThrift(t *testing.T) { dirs, err := ioutil.ReadDir("test_files/include_test") require.NoError(t, err, "Cannot read test_files/include_test directory: %v", err) for _, d := range dirs { dname := d.Name() if !d.IsDir() { continue } thriftFile := filepath.Join(dname, path.Base(dname)+".thrift") if err := runBuildTest(t, filepath.Join("test_files/include_test/", thriftFile)); err != nil { t.Errorf("Thrift test %v failed: %v", dname, err) } } } func TestMultipleFiles(t *testing.T) { if err := runBuildTest(t, filepath.Join("test_files", "multi_test", "file1.thrift")); err != nil { t.Errorf("Multiple file test failed: %v", err) } } func TestExternalTemplate(t *testing.T) { template1 := `package {{ .Package }} {{ range .AST.Services }} // Service {{ .Name }} has {{ len .Methods }} methods. {{ range .Methods }} // func {{ .Name | goPublicName }} ({{ range .Arguments }}{{ .Type | goType }}, {{ end }}) ({{ if .ReturnType }}{{ .ReturnType | goType }}{{ end }}){{ end }} {{ end }} ` templateFile := writeTempFile(t, template1) defer os.Remove(templateFile) expected := `package service_extend // Service S1 has 1 methods. // func M1 ([]byte, ) ([]byte) // Service S2 has 1 methods. // func M2 (*S, int32, ) (*S) // Service S3 has 1 methods. 
// copyFile copies the contents of the file at src into dst. dst is created if
// it does not exist and truncated if it does.
//
// It returns the first error encountered while opening either file, copying
// the data, or closing dst. The close error matters because buffered data may
// only be reported as unwritten when the destination is closed.
func copyFile(src, dst string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	// The source is only read, so a failure to close it cannot lose data.
	defer in.Close()

	out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
	if err != nil {
		return err
	}
	defer func() {
		// Keep an earlier copy error, but do not hide a failed close.
		if closeErr := out.Close(); err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(out, in)
	return err
}
// createAdditionalTestFile extracts Go source embedded in the comments of a
// Thrift file and writes it into tempDir.
//
// A line starting with "//Go code:" names the output file (the remainder of
// the line, trimmed of surrounding whitespace). Every following line that
// starts with "//" is written to that file with the "//" prefix removed.
// Extraction stops at the first line after a marker that is not a "//"
// comment, or at the end of the input. Lines before the first marker are
// ignored.
//
// It returns the first error encountered while opening or reading the Thrift
// file, or while creating, writing or closing an output file.
func createAdditionalTestFile(thriftFile, tempDir string) (err error) {
	in, err := os.Open(thriftFile)
	if err != nil {
		return err
	}
	// The input is only read, so a failure to close it cannot lose data.
	defer in.Close()

	// out is the output file currently being written, if any.
	var out *os.File
	closeOut := func() error {
		if out == nil {
			return nil
		}
		closeErr := out.Close()
		out = nil
		return closeErr
	}
	defer func() {
		// Keep an earlier error, but do not hide a failed close.
		if closeErr := closeOut(); err == nil {
			err = closeErr
		}
	}()

	rdr := bufio.NewReader(in)
	for {
		line, readErr := rdr.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			// Returning here avoids retrying a failing read forever.
			return readErr
		}

		// On io.EOF, line may still hold a final line without a trailing
		// newline, so it is processed before returning.
		switch {
		case strings.HasPrefix(line, "//Go code:"):
			if closeErr := closeOut(); closeErr != nil {
				return closeErr
			}
			fileName := strings.TrimSpace(strings.TrimPrefix(line, "//Go code:"))
			outFile := filepath.Join(tempDir, fileName)
			f, openErr := os.OpenFile(outFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
			if openErr != nil {
				return openErr
			}
			out = f
		case out == nil:
			// No marker seen yet: nothing to write.
		case strings.HasPrefix(line, "//"):
			if _, writeErr := out.WriteString(strings.TrimPrefix(line, "//")); writeErr != nil {
				return writeErr
			}
		default:
			// First non-comment line after a marker ends the embedded code.
			return nil
		}

		if readErr == io.EOF {
			return nil
		}
	}
}
if err := createAdditionalTestFile(opts.InputFile, tempDir); err != nil { return fmt.Errorf("failed creating additional test files for %s in %q: %v", opts.InputFile, tempDir, err) } // Run go build to ensure that the generated code builds. cmd := exec.Command("go", "build", "./...") cmd.Dir = tempDir // NOTE: we check output, since go build ./... returns 0 status code on failure: // https://github.com/golang/go/issues/11407 var ( output, err = cmd.CombinedOutput() outputLines []string ) for _, s := range strings.Split(string(output), "\n") { // Exclude expected output like vendor package downloads and formatting lines if strings.HasPrefix(s, "go: downloading") || strings.TrimSpace(s) == "" { continue } outputLines = append(outputLines, s) } if err != nil || len(outputLines) > 0 { return fmt.Errorf("build in %q failed.\nError: %v Output:\n%v", tempDir, err, string(output)) } // Run any extra checks the user may want. if err := extraChecks(tempDir); err != nil { return err } // Only delete the temp directory on success. os.RemoveAll(tempDir) return nil } ================================================ FILE: thrift/thrift-gen/extends.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "fmt" "sort" "strings" ) // setExtends will set the ExtendsService for all services. // It is done after all files are parsed, as services may extend those // found in an included file. func setExtends(state map[string]parseState) error { for _, v := range state { for _, s := range v.services { if s.Extends == "" { continue } var searchServices []*Service var searchFor string parts := strings.SplitN(s.Extends, ".", 2) // If it's not imported, then look at the current file's services. if len(parts) < 2 { searchServices = v.services searchFor = s.Extends } else { include := v.global.includes[parts[0]] s.ExtendsPrefix = include.pkg + "." searchServices = state[include.file].services searchFor = parts[1] } foundService := sort.Search(len(searchServices), func(i int) bool { return searchServices[i].Name >= searchFor }) if foundService == len(searchServices) { return fmt.Errorf("failed to find base service %q for %q", s.Extends, s.Name) } s.ExtendsService = searchServices[foundService] } } return nil } ================================================ FILE: thrift/thrift-gen/generate.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "flag" "fmt" "os" "os/exec" "path/filepath" "strings" ) var ( thriftBinary = flag.String("thriftBinary", "thrift", "Command to use for the Apache Thrift binary") apacheThriftImport = flag.String("thriftImport", "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift", "Go package to use for the Thrift import") packagePrefix = flag.String("packagePrefix", "", "The package prefix (will be used similar to how Apache Thrift uses it)") ) func execCmd(name string, args ...string) error { cmd := exec.Command(name, args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr return cmd.Run() } func execThrift(args ...string) error { return execCmd(*thriftBinary, args...) 
} func deleteRemote(dir string) error { return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } if !info.IsDir() || !strings.HasSuffix(path, "-remote") { return nil } if err := os.RemoveAll(path); err != nil { return err } // Once the directory is deleted, we can skip the rest of it. return filepath.SkipDir }) } func runThrift(inFile string, outDir string) error { inFile, err := filepath.Abs(inFile) if err != nil { return err } // Delete any existing generated code for this Thrift file. genDir := filepath.Join(outDir, defaultPackageName(inFile)) if err := execCmd("rm", "-rf", genDir); err != nil { return fmt.Errorf("failed to delete directory %s: %v", genDir, err) } // Generate the Apache Thrift generated code. goArgs := fmt.Sprintf("go:thrift_import=%s,package_prefix=%s", *apacheThriftImport, *packagePrefix) if err := execThrift("-r", "--gen", goArgs, "-out", outDir, inFile); err != nil { return fmt.Errorf("thrift compile failed: %v", err) } // Delete the -remote folders. if err := deleteRemote(outDir); err != nil { return fmt.Errorf("failed to delete -remote folders: %v", err) } return nil } ================================================ FILE: thrift/thrift-gen/gopath.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "fmt" "os" "path/filepath" ) // ResolveWithGoPath will resolve the filename relative to GOPATH and returns // the first file that exists, or an error otherwise. func ResolveWithGoPath(filename string) (string, error) { for _, file := range goPathCandidates(filename) { if _, err := os.Stat(file); !os.IsNotExist(err) { return file, nil } } return "", fmt.Errorf("file not found on GOPATH: %q", filename) } func goPathCandidates(filename string) []string { candidates := []string{filename} paths := filepath.SplitList(os.Getenv("GOPATH")) for _, path := range paths { resolvedFilename := filepath.Join(path, "src", filename) candidates = append(candidates, resolvedFilename) } return candidates } ================================================ FILE: thrift/thrift-gen/gopath_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "io/ioutil" "os" "path/filepath" "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func getFakeFS(t *testing.T) string { files := []string{ "src/pkg1/sub/ringpop.thriftgen", "src/pkg2/sub/ringpop.thriftgen", } tempDir, err := ioutil.TempDir("", "thriftgen") require.NoError(t, err, "TempDir failed") for _, f := range files { require.NoError(t, os.MkdirAll(filepath.Join(tempDir, filepath.Dir(f)), 0770), "Failed to create directory structure for %v", f) require.NoError(t, ioutil.WriteFile(filepath.Join(tempDir, f), nil, 0660), "Failed to create dummy file") } return tempDir } func TestGoPathCandidates(t *testing.T) { tests := []struct { goPath string filename string expectedCandidates []string }{ { goPath: "onepath", filename: "github.com/uber/tchannel-go/tchan.thrift-gen", expectedCandidates: []string{ "github.com/uber/tchannel-go/tchan.thrift-gen", "onepath/src/github.com/uber/tchannel-go/tchan.thrift-gen", }, }, { goPath: "onepath:secondpath", filename: "github.com/uber/tchannel-go/tchan.thrift-gen", expectedCandidates: []string{ "github.com/uber/tchannel-go/tchan.thrift-gen", "onepath/src/github.com/uber/tchannel-go/tchan.thrift-gen", "secondpath/src/github.com/uber/tchannel-go/tchan.thrift-gen", }, }, } for _, tt := range tests { os.Setenv("GOPATH", tt.goPath) candidates := goPathCandidates(tt.filename) if !reflect.DeepEqual(candidates, tt.expectedCandidates) { t.Errorf("GOPATH=%s FileCandidatesWithGopath(%s) = %q, want 
%q", tt.goPath, tt.filename, candidates, tt.expectedCandidates) } } } func TestResolveWithGoPath(t *testing.T) { goPath1 := getFakeFS(t) goPath2 := getFakeFS(t) os.Setenv("GOPATH", goPath1+string(filepath.ListSeparator)+goPath2) defer os.RemoveAll(goPath1) defer os.RemoveAll(goPath2) tests := []struct { filename string want string wantErr bool }{ { filename: "pkg1/sub/ringpop.thriftgen", want: filepath.Join(goPath1, "src/pkg1/sub/ringpop.thriftgen"), }, { filename: "pkg2/sub/ringpop.thriftgen", want: filepath.Join(goPath1, "src/pkg2/sub/ringpop.thriftgen"), }, { filename: filepath.Join(goPath2, "src/pkg2/sub/ringpop.thriftgen"), want: filepath.Join(goPath2, "src/pkg2/sub/ringpop.thriftgen"), }, { filename: "pkg3/sub/ringpop.thriftgen", wantErr: true, }, } for _, tt := range tests { file, err := ResolveWithGoPath(tt.filename) gotErr := err != nil assert.Equal(t, tt.wantErr, gotErr, "%v expected error: %v got: %v", tt.filename, tt.wantErr, err) assert.Equal(t, tt.want, file, "%v expected to resolve to %v, got %v", tt.filename, tt.want, file) } } ================================================ FILE: thrift/thrift-gen/include.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import "github.com/samuel/go-thrift/parser" // Include represents a single include statement in the Thrift file. type Include struct { key string file string pkg string } // Import returns the go import to use for this package. func (i *Include) Import() string { // TODO(prashant): Rename imports so they don't clash with standard imports. // This is not high priority since Apache thrift clashes already with "bytes" and "fmt". // which are the same imports we would clash with. return *packagePrefix + i.Package() } // Package returns the package selector for this package. func (i *Include) Package() string { return i.pkg } func createIncludes(parsed *parser.Thrift, all map[string]parseState) map[string]*Include { includes := make(map[string]*Include) for k, v := range parsed.Includes { included := all[v] includes[k] = &Include{ key: k, file: v, pkg: included.namespace, } } return includes } ================================================ FILE: thrift/thrift-gen/main.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // thrift-gen generates code for Thrift services that can be used with the // uber/tchannel/thrift package. thrift-gen generated code relies on the // Apache Thrift generated code for serialization/deserialization, and should // be a part of the generated code's package. 
package main

import (
	"flag"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"text/template"

	"github.com/samuel/go-thrift/parser"
)

// tchannelThriftImport is the import path emitted for the TChannel Thrift package.
const tchannelThriftImport = "github.com/uber/tchannel-go/thrift"

var (
	generateThrift = flag.Bool("generateThrift", false, "Whether to generate all Thrift go code")
	inputFile      = flag.String("inputFile", "", "The .thrift file to generate a client for")
	outputDir      = flag.String("outputDir", "gen-go", "The output directory to generate go code to.")
	skipTChannel   = flag.Bool("skipTChannel", false, "Whether to skip the TChannel template")
	templateFiles  = NewStringSliceFlag("template", "Template file to compile code from")

	// nlSpaceNL matches a line that contains only spaces and tabs.
	nlSpaceNL = regexp.MustCompile(`\n[ \t]+\n`)
)

// TemplateData carries everything a code-generation template can reference.
type TemplateData struct {
	Package  string
	AST      *parser.Thrift
	Services []*Service
	Includes map[string]*Include
	Imports  imports

	// global should not be directly exported to the template, but functions on
	// global can be exposed to templates.
	global *State
}

// imports holds the import paths that templates emit for the Thrift and
// TChannel packages.
type imports struct {
	Thrift   string
	TChannel string
}

// main parses the command-line flags and generates code for the requested
// input file, exiting through log.Fatal on any failure.
func main() {
	flag.Parse()
	if *inputFile == "" {
		log.Fatalf("Please specify an inputFile")
	}

	err := processFile(processOptions{
		InputFile:      *inputFile,
		GenerateThrift: *generateThrift,
		OutputDir:      *outputDir,
		SkipTChannel:   *skipTChannel,
		TemplateFiles:  *templateFiles,
	})
	if err != nil {
		log.Fatal(err)
	}
}

// processOptions groups the settings for a single processFile run.
type processOptions struct {
	InputFile      string
	GenerateThrift bool
	OutputDir      string
	SkipTChannel   bool
	TemplateFiles  []string
}

// processFile creates the output directory, optionally runs the Apache Thrift
// compiler, parses the input file, and then renders every template once for
// each parsed file, writing to <OutputDir>/<package>/<template output name>.
func processFile(opts processOptions) error {
	if err := os.MkdirAll(opts.OutputDir, 0770); err != nil {
		return fmt.Errorf("failed to create output directory %q: %v", opts.OutputDir, err)
	}

	if opts.GenerateThrift {
		if err := runThrift(opts.InputFile, opts.OutputDir); err != nil {
			return fmt.Errorf("failed to run thrift for file %q: %v", opts.InputFile, err)
		}
	}

	allParsed, err := parseFile(opts.InputFile)
	if err != nil {
		return fmt.Errorf("failed to parse file %q: %v", opts.InputFile, err)
	}

	allTemplates, err := parseTemplates(opts.SkipTChannel, opts.TemplateFiles)
	if err != nil {
		return fmt.Errorf("failed to parse templates: %v", err)
	}

	for filename, parsed := range allParsed {
		pkg := getNamespace(filename, parsed.ast)
		pkgDir := filepath.Join(opts.OutputDir, pkg)

		for _, tmpl := range allTemplates {
			target := filepath.Join(pkgDir, tmpl.outputFile(pkg))
			if err := generateCode(target, tmpl, pkg, parsed); err != nil {
				return err
			}
		}
	}

	return nil
}

// parseState is the per-file result of parsing: the AST, the Go namespace,
// the generator state, and the wrapped services.
type parseState struct {
	ast       *parser.Thrift
	namespace string
	global    *State
	services  []*Service
}

// parseTemplates returns a list of Templates that must be rendered given the template files.
func parseTemplates(skipTChannel bool, templateFiles []string) ([]*Template, error) { var templates []*Template if !skipTChannel { templates = append(templates, &Template{ name: "tchan", template: template.Must(parseTemplate(tchannelTmpl)), }) } for _, f := range templateFiles { t, err := parseTemplateFile(f) if err != nil { return nil, err } templates = append(templates, t) } return templates, nil } func parseFile(inputFile string) (map[string]parseState, error) { parser := &parser.Parser{} parsed, _, err := parser.ParseFile(inputFile) if err != nil { return nil, err } allParsed := make(map[string]parseState) for filename, v := range parsed { state := newState(v, allParsed) services, err := wrapServices(v, state) if err != nil { return nil, fmt.Errorf("wrap services failed: %v", err) } namespace := getNamespace(filename, v) allParsed[filename] = parseState{v, namespace, state, services} } setIncludes(allParsed) return allParsed, setExtends(allParsed) } func defaultPackageName(fullPath string) string { filename := filepath.Base(fullPath) file := strings.TrimSuffix(filename, filepath.Ext(filename)) return strings.ToLower(file) } func getNamespace(filename string, v *parser.Thrift) string { if ns, ok := v.Namespaces["go"]; ok { return ns } // TODO(prashant): Remove any characters that are not valid in Go package names. 
return defaultPackageName(filename) } func generateCode(outputFile string, template *Template, pkg string, state parseState) error { if outputFile == "" { return fmt.Errorf("must speciy an output file") } if len(state.services) == 0 { return nil } td := TemplateData{ Package: pkg, AST: state.ast, Includes: state.global.includes, Services: state.services, global: state.global, Imports: imports{ Thrift: *apacheThriftImport, TChannel: tchannelThriftImport, }, } return template.execute(outputFile, td) } type stringSliceFlag []string func (s *stringSliceFlag) String() string { return strings.Join(*s, ", ") } func (s *stringSliceFlag) Set(in string) error { *s = append(*s, in) return nil } // NewStringSliceFlag creates a new string slice flag. The default value is always nil. func NewStringSliceFlag(name string, usage string) *[]string { var ss stringSliceFlag flag.Var(&ss, name, usage) return (*[]string)(&ss) } ================================================ FILE: thrift/thrift-gen/names.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main // This file implements go name generation for thrift identifiers. // It has to match the Apache Thrift generated names. import "strings" // goKeywords taken from https://golang.org/ref/spec#Keywords (and added error). var goKeywords = map[string]bool{ "error": true, "break": true, "default": true, "func": true, "interface": true, "select": true, "case": true, "defer": true, "go": true, "map": true, "struct": true, "chan": true, "else": true, "goto": true, "package": true, "switch": true, "const": true, "fallthrough": true, "if": true, "range": true, "type": true, "continue": true, "for": true, "import": true, "return": true, "var": true, } // This set is taken from https://github.com/golang/lint/blob/master/lint.go#L692 var commonInitialisms = map[string]bool{ "API": true, "ASCII": true, "CPU": true, "CSS": true, "DNS": true, "EOF": true, "GUID": true, "HTML": true, "HTTP": true, "HTTPS": true, "ID": true, "IP": true, "JSON": true, "LHS": true, "QPS": true, "RAM": true, "RHS": true, "RPC": true, "SLA": true, "SMTP": true, "SQL": true, "SSH": true, "TCP": true, "TLS": true, "TTL": true, "UDP": true, "UI": true, "UID": true, "UUID": true, "URI": true, "URL": true, "UTF8": true, "VM": true, "XML": true, "XSRF": true, "XSS": true, } func goName(name string) string { // Thrift Identifier from IDL: ( Letter | '_' ) ( Letter | Digit | '.' | '_' )* // Go identifier from spec: letter { letter | unicode_digit } . // Go letter allows underscore, so the only difference is period. However, periods cannot // actaully be used - this seems to be a bug in the IDL. if _, ok := goKeywords[name]; ok { // The thrift compiler appends _a1 for any clashes with go keywords. 
name += "_a1" } name = camelCase(name, false /* publicName */) return name } // camelCase takes a name with underscores such as my_arg and returns camelCase (e.g. myArg). // if publicName is true, then it returns UpperCamelCase. // This method will also fix common initialisms (e.g. ID, API, etc). func camelCase(name string, publicName bool) string { parts := strings.Split(name, "_") startAt := 1 if publicName { startAt = 0 } for i := startAt; i < len(parts); i++ { name := parts[i] if name == "" { continue } // For all words except the first, if the first letter of the word is // uppercase, Thrift keeps the underscore. if i > 0 && strings.ToUpper(name[0:1]) == name[0:1] { name = "_" + name } else { name = strings.ToUpper(name[0:1]) + name[1:] } if isInitialism := commonInitialisms[strings.ToUpper(name)]; isInitialism { name = strings.ToUpper(name) } parts[i] = name } return strings.Join(parts, "") } func avoidThriftClash(name string) string { if strings.HasSuffix(name, "Result") || strings.HasSuffix(name, "Args") || strings.HasPrefix(name, "New") { return name + "_" } return name } // goPublicName returns a go identifier that is exported. func goPublicName(name string) string { return camelCase(name, true /* publicName */) } // goPublicFieldName returns the name of the field as used in a struct. func goPublicFieldName(name string) string { return avoidThriftClash(goPublicName(name)) } var thriftToGo = map[string]string{ "bool": "bool", "byte": "int8", "i16": "int16", "i32": "int32", "i64": "int64", "double": "float64", "string": "string", } ================================================ FILE: thrift/thrift-gen/tchannel-template.go ================================================ package main var tchannelTmpl = ` // @generated Code generated by thrift-gen. Do not modify. // Package {{ .Package }} is generated code used to make or handle TChannel calls using Thrift. 
package {{ .Package }} import ( "fmt" athrift "{{ .Imports.Thrift }}" "{{ .Imports.TChannel }}" {{ range .Includes }} "{{ .Import }}" {{ end }} ) {{ range .Includes }} var _ = {{ .Package }}.GoUnusedProtection__ {{ end }} // Interfaces for the service and client for the services defined in the IDL. {{ range .Services }} // {{ .Interface }} is the interface that defines the server handler and client interface. type {{ .Interface }} interface { {{ if .HasExtends }} {{ .ExtendsServicePrefix }}{{ .ExtendsService.Interface }} {{ end }} {{ range .Methods }} {{ .Name }}({{ .ArgList }}) {{ .RetType }} {{ end }} } {{ end }} // Implementation of a client and service handler. {{/* Generate client and service implementations for the above interfaces. */}} {{ range $svc := .Services }} type {{ .ClientStruct }} struct { {{ if .HasExtends }} {{ .ExtendsServicePrefix }}{{ .ExtendsService.Interface }} {{ end }} thriftService string client thrift.TChanClient } func {{ .InheritedClientConstructor }}(thriftService string, client thrift.TChanClient) *{{ .ClientStruct }} { return &{{ .ClientStruct }}{ {{ if .HasExtends }} {{ .ExtendsServicePrefix }}{{ .ExtendsService.InheritedClientConstructor }}(thriftService, client), {{ end }} thriftService, client, } } // {{ .ClientConstructor }} creates a client that can be used to make remote calls. 
func {{ .ClientConstructor }}(client thrift.TChanClient) {{ .Interface }} { return {{ .InheritedClientConstructor }}("{{ .ThriftName }}", client) } {{ range .Methods }} func (c *{{ $svc.ClientStruct }}) {{ .Name }}({{ .ArgList }}) {{ .RetType }} { var resp {{ .ResultType }} args := {{ .ArgsType }}{ {{ range .Arguments }} {{ .ArgStructName }}: {{ .Name }}, {{ end }} } success, err := c.client.Call(ctx, c.thriftService, "{{ .ThriftName }}", &args, &resp) if err == nil && !success { switch { {{ range .Exceptions }} case resp.{{ .ArgStructName }} != nil: err = resp.{{ .ArgStructName }} {{ end }} default: err = fmt.Errorf("received no result or unknown exception for {{ .ThriftName }}") } } {{ if .HasReturn }} return resp.GetSuccess(), err {{ else }} return err {{ end }} } {{ end }} type {{ .ServerStruct }} struct { {{ if .HasExtends }} thrift.TChanServer {{ end }} handler {{ .Interface }} } // {{ .ServerConstructor }} wraps a handler for {{ .Interface }} so it can be // registered with a thrift.Server. func {{ .ServerConstructor }}(handler {{ .Interface }}) thrift.TChanServer { return &{{ .ServerStruct }}{ {{ if .HasExtends }} {{ .ExtendsServicePrefix }}{{ .ExtendsService.ServerConstructor }}(handler), {{ end }} handler, } } func (s *{{ .ServerStruct }}) Service() string { return "{{ .ThriftName }}" } func (s *{{ .ServerStruct }}) Methods() []string { return []string{ {{ range .Methods }} "{{ .ThriftName }}", {{ end }} {{ range .InheritedMethods }} "{{ . }}", {{ end }} } } func (s *{{ .ServerStruct }}) Handle(ctx {{ contextType }}, methodName string, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { switch methodName { {{ range .Methods }} case "{{ .ThriftName }}": return s.{{ .HandleFunc }}(ctx, protocol) {{ end }} {{ range .InheritedMethods }} case "{{ . 
}}": return s.TChanServer.Handle(ctx, methodName, protocol) {{ end }} default: return false, nil, fmt.Errorf("method %v not found in service %v", methodName, s.Service()) } } {{ range .Methods }} func (s *{{ $svc.ServerStruct }}) {{ .HandleFunc }}(ctx {{ contextType }}, protocol athrift.TProtocol) (bool, athrift.TStruct, error) { var req {{ .ArgsType }} var res {{ .ResultType }} if err := req.Read(protocol); err != nil { return false, nil, err } {{ if .HasReturn }} r, err := {{ else }} err := {{ end }} s.handler.{{ .Name }}({{ .CallList "req" }}) if err != nil { {{ if .HasExceptions }} switch v := err.(type) { {{ range .Exceptions }} case {{ .ArgType }}: if v == nil { return false, nil, fmt.Errorf("Handler for {{ .Name }} returned non-nil error type {{ .ArgType }} but nil value") } res.{{ .ArgStructName }} = v {{ end }} default: return false, nil, err } {{ else }} return false, nil, err {{ end }} } else { {{ if .HasReturn }} res.Success = {{ .WrapResult "r" }} {{ end }} } return err == nil, &res, nil } {{ end }} {{ end }} ` ================================================ FILE: thrift/thrift-gen/template.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. // thrift-gen generates code for Thrift services that can be used with the // uber/tchannel/thrift package. thrift-gen generated code relies on the // Apache Thrift generated code for serialization/deserialization, and should // be a part of the generated code's package. package main import ( "bytes" "fmt" "io/ioutil" "os/exec" "text/template" ) // Template represents a single thrift-gen template that will be used to generate code. type Template struct { name string template *template.Template } // dummyGoType is a function to be passed to the test/template package as the named // function 'goType'. This named function is overwritten by an actual implementation // specific to the thrift file being rendered at the time of rendering. 
func dummyGoType() string { return "" } func parseTemplate(contents string) (*template.Template, error) { funcs := map[string]interface{}{ "contextType": contextType, "goPrivateName": goName, "goPublicName": goPublicName, "goType": dummyGoType, } return template.New("thrift-gen").Funcs(funcs).Parse(contents) } func parseTemplateFile(file string) (*Template, error) { file, err := ResolveWithGoPath(file) if err != nil { return nil, err } bytes, err := ioutil.ReadFile(file) if err != nil { return nil, fmt.Errorf("failed to read file %q: %v", file, err) } t, err := parseTemplate(string(bytes)) if err != nil { return nil, fmt.Errorf("failed to parse template in file %q: %v", file, err) } return &Template{defaultPackageName(file), t}, nil } func contextType() string { return "thrift.Context" } func cleanGeneratedCode(generated []byte) []byte { generated = nlSpaceNL.ReplaceAll(generated, []byte("\n")) return generated } // withStateFuncs adds functions to the template that are dependent upon state. 
func (t *Template) withStateFuncs(td TemplateData) *template.Template { return t.template.Funcs(map[string]interface{}{ "goType": td.global.goType, }) } func (t *Template) execute(outputFile string, td TemplateData) error { buf := &bytes.Buffer{} if err := t.withStateFuncs(td).Execute(buf, td); err != nil { return fmt.Errorf("failed to execute template: %v", err) } generated := cleanGeneratedCode(buf.Bytes()) if err := ioutil.WriteFile(outputFile, generated, 0660); err != nil { return fmt.Errorf("cannot write output file %q: %v", outputFile, err) } // Run gofmt on the file (ignore any errors) exec.Command("gofmt", "-w", outputFile).Run() return nil } func (t *Template) outputFile(pkg string) string { return fmt.Sprintf("%v-%v.go", t.name, pkg) } ================================================ FILE: thrift/thrift-gen/test_files/binary.thrift ================================================ typedef binary Z struct S { 1: binary s1 2: Z s2 } service Test { binary M1(1: binary arg1) S M2(1: binary arg1, 2: S arg2) } ================================================ FILE: thrift/thrift-gen/test_files/byte.thrift ================================================ typedef byte Z struct S { 1: byte s1 2: Z s2 } service Test { byte M1(1: byte arg1) S M2(1: byte arg1, 2: S arg2) } ================================================ FILE: thrift/thrift-gen/test_files/gokeywords.thrift ================================================ // Test to make sure that reserved names are handled correctly. 
exception Exception { 1: required string message } struct Result { 1: required string error 2: required i32 func 3: required i32 chan 4: required i32 result 5: required i64 newRole } service func { string func1() void func(1: i32 func) Result chan(1: i32 func, 2: i32 result) throws (1: Exception error) } ================================================ FILE: thrift/thrift-gen/test_files/include_test/namespace/a/shared.thrift ================================================ namespace go a_shared include "../b/shared.thrift" typedef shared.b_string a_string service AShared extends shared.BShared { bool healthA() } ================================================ FILE: thrift/thrift-gen/test_files/include_test/namespace/b/shared.thrift ================================================ namespace go b_shared typedef string b_string service BShared { bool healthB() } ================================================ FILE: thrift/thrift-gen/test_files/include_test/namespace/namespace.thrift ================================================ include "a/shared.thrift" service Foo extends shared.AShared { void Foo(1: shared.a_string str) } ================================================ FILE: thrift/thrift-gen/test_files/include_test/simple/shared.thrift ================================================ include "shared2.thrift" typedef string a_shared_string typedef shared2.a_shared2_string a_shared_string2 typedef shared2.a_shared2_mystruct a_shared_mystruct2 ================================================ FILE: thrift/thrift-gen/test_files/include_test/simple/shared2.thrift ================================================ typedef string a_shared2_string struct MyStruct { 1: string name } typedef MyStruct a_shared2_mystruct ================================================ FILE: thrift/thrift-gen/test_files/include_test/simple/simple.thrift ================================================ include "shared.thrift" include "shared2.thrift" service Foo { void Foo(1: 
shared.a_shared_string str, 2: shared.a_shared_string2 str2, 3: shared2.MyStruct str3) } ================================================ FILE: thrift/thrift-gen/test_files/include_test/svc_extend/shared.thrift ================================================ typedef string UUID struct Health { 1: bool ok } service FooBase { UUID getUUID() } ================================================ FILE: thrift/thrift-gen/test_files/include_test/svc_extend/svc_extend.thrift ================================================ include "shared.thrift" service Foo extends shared.FooBase { shared.UUID getMyUUID(1: shared.UUID uuid, 2: shared.Health health) shared.Health health(1: shared.UUID uuid, 2: shared.Health health) } //Go code: svc_extend/test.go // package svc_extend // var _ = TChanFoo(nil).GetMyUUID // var _ = TChanFoo(nil).Health // var _ = TChanFoo(nil).GetUUID ================================================ FILE: thrift/thrift-gen/test_files/multi_test/file1.thrift ================================================ include "file2.thrift" service Foo1 { void M() } ================================================ FILE: thrift/thrift-gen/test_files/multi_test/file2.thrift ================================================ service Foo2 { void M() } ================================================ FILE: thrift/thrift-gen/test_files/service_extend.thrift ================================================ struct S { 1: binary s1 } service S1 { binary M1(1: binary bits) } service S2 extends S1 { S M2(1: optional S s, 2: optional i32 i) } service S3 extends S2 { void M3() } //Go code: service_extend/test.go // package service_extend // var _ = TChanS3(nil).M1 // var _ = TChanS3(nil).M2 // var _ = TChanS3(nil).M3 // var _ int32 = S2M2Args{}.I ================================================ FILE: thrift/thrift-gen/test_files/sets.thrift ================================================ service Test { list getInts(1: list nums) set getIntSet(1: set nums) map getIntMap(1: map nums) } 
================================================ FILE: thrift/thrift-gen/test_files/test1.thrift ================================================ struct FakeStruct { 1: i64 id 2: i64 user_id } service Fake { // Test initialisms in the method name (as well as name clashes). void id_get() void id() void get_id() void get_Id() void get_ID() // Test initialisms in parameter names. void initialisms_in_args1(1: string LoL_http_TEST_Name) void initialisms_in_args2(1: string user_id) void initialisms_in_args3(1: string id) // Test casing for method names void fAkE(1: i32 func, 2: i32 pkg, 3: FakeStruct fakeStruct) void MyArgs() void MyResult() } ================================================ FILE: thrift/thrift-gen/test_files/typedefs.thrift ================================================ typedef i64 X typedef X Z typedef X Y typedef i64 i typedef i64 func struct S { 1: X x 2: Y y 3: Z z } typedef S ST enum Operator { ADD = 1, SUBTRACT = 2 } service Test { Y M1(1: X arg1, 2: i arg2) X M2(1: Y arg1) Z M3(1: X arg1) S M4(1: S arg1, 2: Operator op) // Thrift compiler is broken on this case. // ST M5(1: ST arg1, 2: S arg2) } ================================================ FILE: thrift/thrift-gen/test_files/union.thrift ================================================ union Constraint { 1: i32 intV 2: string stringV } // thrift-gen generated code assumes there is a service. service Test {} ================================================ FILE: thrift/thrift-gen/typestate.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "strings" "github.com/samuel/go-thrift/parser" ) // State is global Thrift state for a file with type information. type State struct { // typedefs is a map from a typedef name to the underlying type. typedefs map[string]*parser.Type // includes is a map from Thrift base name to the include. includes map[string]*Include // all is used for includes. all map[string]parseState } // newState parses the type information for a parsed Thrift file and returns the state. func newState(v *parser.Thrift, all map[string]parseState) *State { typedefs := make(map[string]*parser.Type) for k, v := range v.Typedefs { typedefs[k] = v.Type } // Enums are typedefs to an int64. 
i64Type := &parser.Type{Name: "i64"} for k := range v.Enums { typedefs[k] = i64Type } return &State{typedefs, nil, all} } func setIncludes(all map[string]parseState) { for _, v := range all { v.global.includes = createIncludes(v.ast, all) } } func (s *State) isBasicType(thriftType string) bool { _, ok := thriftToGo[thriftType] return ok } // rootType recurses through typedefs and returns the underlying type. func (s *State) rootType(thriftType *parser.Type) *parser.Type { if state, newType, include := s.checkInclude(thriftType); include != nil { return state.rootType(newType) } if v, ok := s.typedefs[thriftType.Name]; ok { return s.rootType(v) } return thriftType } // checkInclude will check if the type is an included type, and if so, return the // state and type from the state for that file. func (s *State) checkInclude(thriftType *parser.Type) (*State, *parser.Type, *Include) { parts := strings.SplitN(thriftType.Name, ".", 2) if len(parts) < 2 { return nil, nil, nil } newType := *thriftType newType.Name = parts[1] include := s.includes[parts[0]] state := s.all[include.file] return state.global, &newType, include } // isResultPointer returns whether the result for this method is a pointer. func (s *State) isResultPointer(thriftType *parser.Type) bool { _, basicGoType := thriftToGo[s.rootType(thriftType).Name] return !basicGoType } // goType returns the Go type name for the given thrift type. func (s *State) goType(thriftType *parser.Type) string { return s.goTypePrefix("", thriftType) } // goTypePrefix returns the Go type name for the given thrift type with the prefix. 
func (s *State) goTypePrefix(prefix string, thriftType *parser.Type) string { switch thriftType.Name { case "binary": return "[]byte" case "list": return "[]" + s.goType(thriftType.ValueType) case "set": return "map[" + s.goType(thriftType.ValueType) + "]bool" case "map": return "map[" + s.goType(thriftType.KeyType) + "]" + s.goType(thriftType.ValueType) } // If the type is imported, then ignore the package. if state, newType, include := s.checkInclude(thriftType); include != nil { return state.goTypePrefix(include.Package()+".", newType) } // If the type is a direct Go type, use that. if goType, ok := thriftToGo[thriftType.Name]; ok { return goType } goThriftName := goPublicFieldName(thriftType.Name) goThriftName = prefix + goThriftName // Check if the type has a typedef to the direct Go type. rootType := s.rootType(thriftType) if _, ok := thriftToGo[rootType.Name]; ok { return goThriftName } if rootType.Name == "list" || rootType.Name == "set" || rootType.Name == "map" { return goThriftName } // If it's a typedef to another struct, then the typedef is defined as a pointer // so we do not want the pointer type here. if rootType != thriftType { return goThriftName } // If it's not a typedef for a basic type, we use a pointer. return "*" + goThriftName } ================================================ FILE: thrift/thrift-gen/validate.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package main import ( "fmt" "github.com/samuel/go-thrift/parser" ) // Validate validates that the given spec is supported by thrift-gen. func Validate(svc *parser.Service) error { for _, m := range svc.Methods { if err := validateMethod(svc, m); err != nil { return err } } return nil } func validateMethod(svc *parser.Service, m *parser.Method) error { if m.Oneway { return fmt.Errorf("oneway methods are not supported: %s.%v", svc.Name, m.Name) } for _, arg := range m.Arguments { if arg.Optional { // Go treats argument structs as "Required" in the generated code interface. 
arg.Optional = false } } return nil } ================================================ FILE: thrift/thrift-gen/wrap.go ================================================ package main import ( "fmt" "sort" "strings" "github.com/samuel/go-thrift/parser" ) type byServiceName []*Service func (l byServiceName) Len() int { return len(l) } func (l byServiceName) Less(i, j int) bool { return l[i].Service.Name < l[j].Service.Name } func (l byServiceName) Swap(i, j int) { l[i], l[j] = l[j], l[i] } func wrapServices(v *parser.Thrift, state *State) ([]*Service, error) { var services []*Service for _, s := range v.Services { if err := Validate(s); err != nil { return nil, err } services = append(services, &Service{ Service: s, state: state, }) } // Have a stable ordering for services so code generation is consistent. sort.Sort(byServiceName(services)) return services, nil } // Service is a wrapper for parser.Service. type Service struct { *parser.Service state *State // ExtendsService and ExtendsPrefix are set in `setExtends`. ExtendsService *Service ExtendsPrefix string // methods is a cache of all methods. methods []*Method // inheritedMethods is a list of inherited method names. inheritedMethods []string } // ThriftName returns the thrift identifier for this service. func (s *Service) ThriftName() string { return s.Service.Name } // Interface returns the name of the interface representing the service. func (s *Service) Interface() string { return "TChan" + goPublicName(s.Name) } // ClientStruct returns the name of the unexported struct that satisfies the interface as a client. func (s *Service) ClientStruct() string { return "tchan" + goPublicName(s.Name) + "Client" } // ClientConstructor returns the name of the constructor used to create a client. func (s *Service) ClientConstructor() string { return "NewTChan" + goPublicName(s.Name) + "Client" } // InheritedClientConstructor returns the name of the constructor used by the generated code // for inherited services. 
This allows the parent service to set the service name that should // be used. func (s *Service) InheritedClientConstructor() string { return "NewTChan" + goPublicName(s.Name) + "InheritedClient" } // ServerStruct returns the name of the unexported struct that satisfies TChanServer. func (s *Service) ServerStruct() string { return "tchan" + goPublicName(s.Name) + "Server" } // ServerConstructor returns the name of the constructor used to create the TChanServer interface. func (s *Service) ServerConstructor() string { return "NewTChan" + goPublicName(s.Name) + "Server" } // HasExtends returns whether this service extends another service. func (s *Service) HasExtends() bool { return s.ExtendsService != nil } // ExtendsServicePrefix returns a package selector (if any) for the extended service. func (s *Service) ExtendsServicePrefix() string { if dotIndex := strings.Index(s.Extends, "."); dotIndex > 0 { return s.ExtendsPrefix } return "" } type byMethodName []*Method func (l byMethodName) Len() int { return len(l) } func (l byMethodName) Less(i, j int) bool { return l[i].Method.Name < l[j].Method.Name } func (l byMethodName) Swap(i, j int) { l[i], l[j] = l[j], l[i] } // Methods returns the methods on this service, not including methods from inherited services. func (s *Service) Methods() []*Method { if s.methods != nil { return s.methods } for _, m := range s.Service.Methods { s.methods = append(s.methods, &Method{m, s, s.state}) } sort.Sort(byMethodName(s.methods)) return s.methods } // InheritedMethods returns names for inherited methods on this service. func (s *Service) InheritedMethods() []string { if s.inheritedMethods != nil { return s.inheritedMethods } for svc := s.ExtendsService; svc != nil; svc = svc.ExtendsService { for m := range svc.Service.Methods { s.inheritedMethods = append(s.inheritedMethods, m) } } sort.Strings(s.inheritedMethods) return s.inheritedMethods } // Method is a wrapper for parser.Method. 
type Method struct { *parser.Method service *Service state *State } // ThriftName returns the thrift identifier for this function. func (m *Method) ThriftName() string { return m.Method.Name } // Name returns the go method name. func (m *Method) Name() string { return goPublicName(m.Method.Name) } // HandleFunc is the go method name for the handle function which decodes the payload. func (m *Method) HandleFunc() string { return "handle" + goPublicName(m.Method.Name) } // Arguments returns the argument declarations for this method. func (m *Method) Arguments() []*Field { var args []*Field for _, f := range m.Method.Arguments { args = append(args, &Field{f, m.state}) } return args } // Exceptions returns the exceptions that this method may return. func (m *Method) Exceptions() []*Field { var args []*Field for _, f := range m.Method.Exceptions { args = append(args, &Field{f, m.state}) } return args } // HasReturn returns false if this method is declared as void in the Thrift file. func (m *Method) HasReturn() bool { return m.Method.ReturnType != nil } // HasExceptions returns true if this method has func (m *Method) HasExceptions() bool { return len(m.Method.Exceptions) > 0 } func (m *Method) argResPrefix() string { return goPublicName(m.service.Name) + m.Name() } // ArgsType returns the Go name for the struct used to encode the method's arguments. func (m *Method) ArgsType() string { return m.argResPrefix() + "Args" } // ResultType returns the Go name for the struct used to encode the method's result. func (m *Method) ResultType() string { return m.argResPrefix() + "Result" } // ArgList returns the argument list for the function. func (m *Method) ArgList() string { args := []string{"ctx " + contextType()} for _, arg := range m.Arguments() { args = append(args, arg.Declaration()) } return strings.Join(args, ", ") } // CallList creates the call to a function satisfying Interface from an Args struct. 
func (m *Method) CallList(reqStruct string) string { args := []string{"ctx"} for _, arg := range m.Arguments() { args = append(args, reqStruct+"."+arg.ArgStructName()) } return strings.Join(args, ", ") } // RetType returns the go return type of the method. func (m *Method) RetType() string { if !m.HasReturn() { return "error" } return fmt.Sprintf("(%v, %v)", m.state.goType(m.Method.ReturnType), "error") } // WrapResult wraps the result variable before being used in the result struct. func (m *Method) WrapResult(respVar string) string { if !m.HasReturn() { panic("cannot wrap a return when there is no return mode") } if m.state.isResultPointer(m.ReturnType) { return respVar } return "&" + respVar } // ReturnWith takes the result name and the error name, and generates the return expression. func (m *Method) ReturnWith(respName string, errName string) string { if !m.HasReturn() { return errName } return fmt.Sprintf("%v, %v", respName, errName) } // Field is a wrapper for parser.Field. type Field struct { *parser.Field state *State } // Declaration returns the declaration for this field. func (a *Field) Declaration() string { return fmt.Sprintf("%s %s", a.Name(), a.ArgType()) } // Name returns the field name. func (a *Field) Name() string { return goName(a.Field.Name) } // ArgType returns the Go type for the given field. func (a *Field) ArgType() string { return a.state.goType(a.Type) } // ArgStructName returns the name of this field in the Args struct generated by thrift. func (a *Field) ArgStructName() string { return goPublicFieldName(a.Field.Name) } ================================================ FILE: thrift/thrift_bench_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package thrift_test import ( "flag" "testing" "time" "github.com/uber/tchannel-go/benchmark" "github.com/uber/tchannel-go/testutils" "github.com/stretchr/testify/require" "go.uber.org/atomic" ) const callBatch = 100 var ( useHyperbahn = flag.Bool("useHyperbahn", false, "Whether to advertise and route requests through Hyperbahn") hyperbahnNodes = flag.String("hyperbahn-nodes", "127.0.0.1:21300,127.0.0.1:21301", "Comma-separated list of Hyperbahn nodes") requestSize = flag.Int("request-size", 4, "Call payload size") timeout = flag.Duration("call-timeout", time.Second, "Timeout for each call") ) func init() { benchmark.BenchmarkDir = "../benchmark/" } func BenchmarkBothSerial(b *testing.B) { server := benchmark.NewServer() client := benchmark.NewClient( []string{server.HostPort()}, benchmark.WithTimeout(*timeout), benchmark.WithRequestSize(*requestSize), ) b.ResetTimer() for _, calls := range testutils.Batch(b.N, callBatch) { if _, err := client.ThriftCall(calls); err != nil { b.Errorf("Call failed: %v", err) } } } func BenchmarkInboundSerial(b *testing.B) { server := benchmark.NewServer() client := benchmark.NewClient( []string{server.HostPort()}, benchmark.WithTimeout(*timeout), benchmark.WithExternalProcess(), benchmark.WithRequestSize(*requestSize), ) defer client.Close() require.NoError(b, client.Warmup(), "Warmup failed") b.ResetTimer() for _, calls := range testutils.Batch(b.N, callBatch) { if _, err := client.ThriftCall(calls); err != nil { b.Errorf("Call failed: %v", err) } } } func BenchmarkInboundParallel(b *testing.B) { server := benchmark.NewServer() var reqCounter atomic.Int32 started := time.Now() b.RunParallel(func(pb *testing.PB) { client := benchmark.NewClient( []string{server.HostPort()}, benchmark.WithTimeout(*timeout), benchmark.WithExternalProcess(), benchmark.WithRequestSize(*requestSize), ) defer client.Close() require.NoError(b, client.Warmup(), "Warmup failed") for pb.Next() { if _, err := client.ThriftCall(100); err != nil { b.Errorf("Call 
failed: %v", err) } reqCounter.Add(100) } }) duration := time.Since(started) reqs := reqCounter.Load() b.Logf("Requests: %v RPS: %v", reqs, float64(reqs)/duration.Seconds()) } ================================================ FILE: thrift/thrift_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift_test import ( "errors" "fmt" "strings" "testing" "time" "golang.org/x/net/context" // Test is in a separate package to avoid circular dependencies. . 
"github.com/uber/tchannel-go/thrift" tchannel "github.com/uber/tchannel-go" "github.com/uber/tchannel-go/testutils" gen "github.com/uber/tchannel-go/thrift/gen-go/test" "github.com/uber/tchannel-go/thrift/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift" ) // Generate the service mocks using go generate. //go:generate mockery -dir ./gen-go/test -name TChanSimpleService //go:generate mockery -dir ./gen-go/test -name TChanSecondService type testArgs struct { server *Server s1 *mocks.TChanSimpleService s2 *mocks.TChanSecondService c1 gen.TChanSimpleService c2 gen.TChanSecondService serverCh *tchannel.Channel clientCh *tchannel.Channel } func ctxArg() mock.AnythingOfTypeArgument { return mock.AnythingOfType("tchannel.headerCtx") } func TestThriftArgs(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { arg := &gen.Data{ B1: true, S2: "str", I3: 102, } ret := &gen.Data{ B1: false, S2: "return-str", I3: 105, } args.s1.On("Call", ctxArg(), arg).Return(ret, nil) got, err := args.c1.Call(ctx, arg) require.NoError(t, err) assert.Equal(t, ret, got) }) } func TestRequest(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { args.s1.On("Simple", ctxArg()).Return(nil) require.NoError(t, args.c1.Simple(ctx)) }) } func TestRetryRequest(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { count := 0 args.s1.On("Simple", ctxArg()).Return(tchannel.ErrServerBusy). 
Run(func(args mock.Arguments) { count++ }) require.Error(t, args.c1.Simple(ctx), "Simple expected to fail") assert.Equal(t, 5, count, "Expected Simple to be retried 5 times") }) } func TestRequestSubChannel(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() tchan := testutils.NewServer(t, testutils.NewOpts().SetServiceName("svc1")) defer tchan.Close() clientCh := testutils.NewClient(t, nil) defer clientCh.Close() clientCh.Peers().Add(tchan.PeerInfo().HostPort) tests := []tchannel.Registrar{tchan, tchan.GetSubChannel("svc2"), tchan.GetSubChannel("svc3")} for _, ch := range tests { mockHandler := new(mocks.TChanSecondService) server := NewServer(ch) server.Register(gen.NewTChanSecondServiceServer(mockHandler)) client := NewClient(clientCh, ch.ServiceName(), nil) secondClient := gen.NewTChanSecondServiceClient(client) echoArg := ch.ServiceName() echoRes := echoArg + "-echo" mockHandler.On("Echo", ctxArg(), echoArg).Return(echoRes, nil) res, err := secondClient.Echo(ctx, echoArg) assert.NoError(t, err, "Echo failed") assert.Equal(t, echoRes, res) } } func TestLargeRequest(t *testing.T) { arg := testutils.RandString(100000) res := strings.ToLower(arg) fmt.Println(len(arg)) withSetup(t, func(ctx Context, args testArgs) { args.s2.On("Echo", ctxArg(), arg).Return(res, nil) got, err := args.c2.Echo(ctx, arg) if assert.NoError(t, err, "Echo got error") { assert.Equal(t, res, got, "Echo got unexpected response") } }) } func TestThriftError(t *testing.T) { thriftErr := &gen.SimpleErr{ Message: "this is the error", } withSetup(t, func(ctx Context, args testArgs) { args.s1.On("Simple", ctxArg()).Return(thriftErr) got := args.c1.Simple(ctx) require.Error(t, got) require.Equal(t, thriftErr, got) }) } func TestThriftUnknownError(t *testing.T) { thriftErr := &gen.NewErr_{ Message: "new error", } withSetup(t, func(ctx Context, args testArgs) { // When "Simple" is called, actually call a separate similar looking method // SimpleFuture which has a new exception that 
the client side of Simple // does not know how to handle. args.s1.On("SimpleFuture", ctxArg()).Return(thriftErr) tClient := NewClient(args.clientCh, args.serverCh.ServiceName(), nil) rewriteMethodClient := rewriteMethodClient{tClient, "SimpleFuture"} simpleClient := gen.NewTChanSimpleServiceClient(rewriteMethodClient) got := simpleClient.Simple(ctx) require.Error(t, got) assert.Contains(t, got.Error(), "no result or unknown exception") }) } func TestThriftNilErr(t *testing.T) { var thriftErr *gen.SimpleErr withSetup(t, func(ctx Context, args testArgs) { args.s1.On("Simple", ctxArg()).Return(thriftErr) got := args.c1.Simple(ctx) require.Error(t, got) require.Contains(t, got.Error(), "non-nil error type") require.Contains(t, got.Error(), "nil value") }) } func TestThriftDecodeEmptyFrameServer(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { args.s1.On("Simple", ctxArg()).Return(nil) call, err := args.clientCh.BeginCall(ctx, args.serverCh.PeerInfo().HostPort, args.serverCh.ServiceName(), "SimpleService::Simple", nil) require.NoError(t, err, "Failed to BeginCall") withWriter(t, call.Arg2Writer, func(w tchannel.ArgWriter) error { if err := WriteHeaders(w, nil); err != nil { return err } return w.Flush() }) withWriter(t, call.Arg3Writer, func(w tchannel.ArgWriter) error { if err := WriteStruct(w, &gen.SimpleServiceSimpleArgs{}); err != nil { return err } return w.Flush() }) response := call.Response() withReader(t, response.Arg2Reader, func(r tchannel.ArgReader) error { _, err := ReadHeaders(r) return err }) var res gen.SimpleServiceSimpleResult withReader(t, response.Arg3Reader, func(r tchannel.ArgReader) error { return ReadStruct(r, &res) }) assert.False(t, res.IsSetSimpleErr(), "Expected no error") }) } func TestThriftDecodeEmptyFrameClient(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { handler := func(ctx context.Context, call *tchannel.InboundCall) { withReader(t, call.Arg2Reader, func(r tchannel.ArgReader) error { _, err := 
ReadHeaders(r) return err }) withReader(t, call.Arg3Reader, func(r tchannel.ArgReader) error { req := &gen.SimpleServiceSimpleArgs{} return ReadStruct(r, req) }) response := call.Response() withWriter(t, response.Arg2Writer, func(w tchannel.ArgWriter) error { if err := WriteHeaders(w, nil); err != nil { return err } return w.Flush() }) withWriter(t, response.Arg3Writer, func(w tchannel.ArgWriter) error { if err := WriteStruct(w, &gen.SimpleServiceSimpleResult{}); err != nil { return err } return w.Flush() }) } args.serverCh.Register(tchannel.HandlerFunc(handler), "SimpleService::Simple") require.NoError(t, args.c1.Simple(ctx)) }) } func TestUnknownError(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { args.s1.On("Simple", ctxArg()).Return(errors.New("unexpected err")) got := args.c1.Simple(ctx) require.Error(t, got) require.Equal(t, tchannel.NewSystemError(tchannel.ErrCodeUnexpected, "unexpected err"), got) }) } func TestMultiple(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { args.s1.On("Simple", ctxArg()).Return(nil) args.s2.On("Echo", ctxArg(), "test1").Return("test2", nil) require.NoError(t, args.c1.Simple(ctx)) res, err := args.c2.Echo(ctx, "test1") require.NoError(t, err) require.Equal(t, "test2", res) }) } func TestHeaders(t *testing.T) { reqHeaders := map[string]string{"header1": "value1", "header2": "value2"} respHeaders := map[string]string{"resp1": "value1-resp", "resp2": "value2-resp"} withSetup(t, func(ctx Context, args testArgs) { args.s1.On("Simple", ctxArg()).Return(nil).Run(func(args mock.Arguments) { ctx := args.Get(0).(Context) assert.Equal(t, reqHeaders, ctx.Headers(), "request headers mismatch") ctx.SetResponseHeaders(respHeaders) }) ctx = WithHeaders(ctx, reqHeaders) require.NoError(t, args.c1.Simple(ctx)) assert.Equal(t, respHeaders, ctx.ResponseHeaders(), "response headers mismatch") }) } func TestClientHostPort(t *testing.T) { ctx, cancel := NewContext(time.Second) defer cancel() s1ch := testutils.NewServer(t, 
nil) s2ch := testutils.NewServer(t, nil) defer s1ch.Close() defer s2ch.Close() s1ch.Peers().Add(s2ch.PeerInfo().HostPort) s2ch.Peers().Add(s1ch.PeerInfo().HostPort) mock1, mock2 := new(mocks.TChanSecondService), new(mocks.TChanSecondService) NewServer(s1ch).Register(gen.NewTChanSecondServiceServer(mock1)) NewServer(s2ch).Register(gen.NewTChanSecondServiceServer(mock2)) // When we call using a normal client, it can only call the other server (only peer). c1 := gen.NewTChanSecondServiceClient(NewClient(s1ch, s2ch.PeerInfo().ServiceName, nil)) mock2.On("Echo", ctxArg(), "call1").Return("call1", nil) res, err := c1.Echo(ctx, "call1") assert.NoError(t, err, "call1 failed") assert.Equal(t, "call1", res) // When we call using a client that specifies host:port, it should call that server. c2 := gen.NewTChanSecondServiceClient(NewClient(s1ch, s1ch.PeerInfo().ServiceName, &ClientOptions{ HostPort: s1ch.PeerInfo().HostPort, })) mock1.On("Echo", ctxArg(), "call2").Return("call2", nil) res, err = c2.Echo(ctx, "call2") assert.NoError(t, err, "call2 failed") assert.Equal(t, "call2", res) } func TestRegisterPostResponseCB(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { var createdCtx Context ctxKey := "key" ctxValue := "value" args.server.SetContextFn(func(ctx context.Context, method string, headers map[string]string) Context { createdCtx = WithHeaders(context.WithValue(ctx, ctxKey, ctxValue), headers) return createdCtx }) arg := &gen.Data{ B1: true, S2: "str", I3: 102, } ret := &gen.Data{ B1: false, S2: "return-str", I3: 105, } called := make(chan struct{}) cb := func(reqCtx context.Context, method string, response thrift.TStruct) { assert.Equal(t, "Call", method) assert.Equal(t, createdCtx, reqCtx) assert.Equal(t, ctxValue, reqCtx.Value(ctxKey)) res, ok := response.(*gen.SimpleServiceCallResult) if assert.True(t, ok, "response type should be Result struct") { assert.Equal(t, ret, res.GetSuccess(), "result should be returned value") } close(called) } 
args.server.Register(gen.NewTChanSimpleServiceServer(args.s1), OptPostResponse(cb)) args.s1.On("Call", ctxArg(), arg).Return(ret, nil) res, err := args.c1.Call(ctx, arg) require.NoError(t, err, "Call failed") assert.Equal(t, res, ret, "Call return value wrong") select { case <-time.After(time.Second): t.Errorf("post-response callback not called") case <-called: } }) } func TestRegisterPostResponseCBCalledOnError(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { var createdCtx Context ctxKey := "key" ctxValue := "value" args.server.SetContextFn(func(ctx context.Context, method string, headers map[string]string) Context { createdCtx = WithHeaders(context.WithValue(ctx, ctxKey, ctxValue), headers) return createdCtx }) arg := &gen.Data{ B1: true, S2: "str", I3: 102, } retErr := thrift.NewTProtocolException(fmt.Errorf("expected error")) called := make(chan struct{}) cb := func(reqCtx context.Context, method string, response thrift.TStruct) { assert.Equal(t, "Call", method) assert.Equal(t, createdCtx, reqCtx) assert.Equal(t, ctxValue, reqCtx.Value(ctxKey)) assert.Nil(t, response) close(called) } args.server.Register(gen.NewTChanSimpleServiceServer(args.s1), OptPostResponse(cb)) args.s1.On("Call", ctxArg(), arg).Return(nil, retErr) res, err := args.c1.Call(ctx, arg) require.Error(t, err, "Call succeeded instead of failed") require.Nil(t, res, "Call returned value and an error") sysErr, ok := err.(tchannel.SystemError) require.True(t, ok, "Call return error not a system error") assert.Equal(t, tchannel.ErrCodeBadRequest, sysErr.Code(), "Call return error value wrong") assert.Equal(t, retErr.Error(), sysErr.Message(), "Call return error value wrong") select { case <-time.After(time.Second): t.Errorf("post-response callback not called") case <-called: } }) } func TestThriftTimeout(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { handler := make(chan struct{}) args.s2.On("Echo", ctxArg(), "asd").Return("asd", nil).Run(func(args mock.Arguments) { 
time.Sleep(testutils.Timeout(150 * time.Millisecond)) close(handler) }) ctx, cancel := NewContext(testutils.Timeout(100 * time.Millisecond)) defer cancel() _, err := args.c2.Echo(ctx, "asd") assert.Equal(t, err, tchannel.ErrTimeout, "Expect call to time out") // Wait for the handler to return, otherwise the test ends before the Server gets an error. select { case <-handler: case <-time.After(time.Second): t.Errorf("Echo handler did not run") } }) } func TestThriftContextFn(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { args.server.SetContextFn(func(ctx context.Context, method string, headers map[string]string) Context { return WithHeaders(ctx, map[string]string{"custom": "headers"}) }) args.s2.On("Echo", ctxArg(), "test").Return("test", nil).Run(func(args mock.Arguments) { ctx := args.Get(0).(Context) assert.Equal(t, "headers", ctx.Headers()["custom"], "Custom header is missing") }) _, err := args.c2.Echo(ctx, "test") assert.NoError(t, err, "Echo failed") }) } func TestThriftMetaHealthNoArgs(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { c := gen.NewTChanMetaClient(NewClient(args.clientCh, args.serverCh.ServiceName(), nil /* options */)) res, err := c.Health(ctx) require.NoError(t, err) assert.True(t, res.Ok, "Health without args failed") }) } func withSetup(t *testing.T, f func(ctx Context, args testArgs)) { args := testArgs{ s1: new(mocks.TChanSimpleService), s2: new(mocks.TChanSecondService), } ctx, cancel := NewContext(time.Second) defer cancel() // Start server args.serverCh, args.server = setupServer(t, args.s1, args.s2) defer args.serverCh.Close() args.clientCh, args.c1, args.c2 = getClients(t, args.serverCh.PeerInfo(), args.serverCh.ServiceName(), args.clientCh) f(ctx, args) args.s1.AssertExpectations(t) args.s2.AssertExpectations(t) } func setupServer(t *testing.T, h *mocks.TChanSimpleService, sh *mocks.TChanSecondService) (*tchannel.Channel, *Server) { ch := testutils.NewServer(t, nil) server := NewServer(ch) 
server.Register(gen.NewTChanSimpleServiceServer(h)) server.Register(gen.NewTChanSecondServiceServer(sh)) return ch, server } func getClients(t *testing.T, serverInfo tchannel.LocalPeerInfo, svcName string, clientCh *tchannel.Channel) (*tchannel.Channel, gen.TChanSimpleService, gen.TChanSecondService) { ch := testutils.NewClient(t, nil) ch.Peers().Add(serverInfo.HostPort) client := NewClient(ch, svcName, nil) simpleClient := gen.NewTChanSimpleServiceClient(client) secondClient := gen.NewTChanSecondServiceClient(client) return ch, simpleClient, secondClient } func withReader(t *testing.T, readerFn func() (tchannel.ArgReader, error), f func(r tchannel.ArgReader) error) { reader, err := readerFn() require.NoError(t, err, "Failed to get reader") err = f(reader) require.NoError(t, err, "Failed to read contents") require.NoError(t, reader.Close(), "Failed to close reader") } func withWriter(t *testing.T, writerFn func() (tchannel.ArgWriter, error), f func(w tchannel.ArgWriter) error) { writer, err := writerFn() require.NoError(t, err, "Failed to get writer") f(writer) require.NoError(t, err, "Failed to write contents") require.NoError(t, writer.Close(), "Failed to close Writer") } type rewriteMethodClient struct { client TChanClient rewriteTo string } func (c rewriteMethodClient) Call(ctx Context, serviceName, methodName string, req, resp thrift.TStruct) (success bool, err error) { return c.client.Call(ctx, serviceName, c.rewriteTo, req, resp) } ================================================ FILE: thrift/tracing_test.go ================================================ package thrift_test import ( json_encoding "encoding/json" "testing" "github.com/uber/tchannel-go" . 
"github.com/uber/tchannel-go/testutils/testtracing" "github.com/uber/tchannel-go/thrift" gen "github.com/uber/tchannel-go/thrift/gen-go/test" "golang.org/x/net/context" ) // ThriftHandler tests tracing over Thrift encoding type ThriftHandler struct { gen.TChanSimpleService // leave nil so calls to unimplemented methods panic. TraceHandler thriftClient gen.TChanSimpleService t *testing.T } func requestFromThrift(req *gen.Data) *TracingRequest { r := new(TracingRequest) r.ForwardCount = int(req.I3) return r } func requestToThrift(r *TracingRequest) *gen.Data { return &gen.Data{I3: int32(r.ForwardCount)} } func responseFromThrift(t *testing.T, res *gen.Data) (*TracingResponse, error) { var r TracingResponse if err := json_encoding.Unmarshal([]byte(res.S2), &r); err != nil { return nil, err } return &r, nil } func responseToThrift(t *testing.T, r *TracingResponse) (*gen.Data, error) { jsonBytes, err := json_encoding.Marshal(r) if err != nil { return nil, err } return &gen.Data{S2: string(jsonBytes)}, nil } func (h *ThriftHandler) Call(ctx thrift.Context, arg *gen.Data) (*gen.Data, error) { req := requestFromThrift(arg) res, err := h.HandleCall(ctx, req, func(ctx context.Context, req *TracingRequest) (*TracingResponse, error) { tctx := ctx.(thrift.Context) res, err := h.thriftClient.Call(tctx, requestToThrift(req)) if err != nil { return nil, err } return responseFromThrift(h.t, res) }) if err != nil { return nil, err } return responseToThrift(h.t, res) } func (h *ThriftHandler) firstCall(ctx context.Context, req *TracingRequest) (*TracingResponse, error) { tctx := thrift.Wrap(ctx) res, err := h.thriftClient.Call(tctx, requestToThrift(req)) if err != nil { return nil, err } return responseFromThrift(h.t, res) } func TestThriftTracingPropagation(t *testing.T) { suite := &PropagationTestSuite{ Encoding: EncodingInfo{Format: tchannel.Thrift, HeadersSupported: true}, Register: func(t *testing.T, ch *tchannel.Channel) TracingCall { opts := &thrift.ClientOptions{HostPort: 
ch.PeerInfo().HostPort} thriftClient := thrift.NewClient(ch, ch.PeerInfo().ServiceName, opts) handler := &ThriftHandler{ TraceHandler: TraceHandler{Ch: ch}, t: t, thriftClient: gen.NewTChanSimpleServiceClient(thriftClient), } // Register Thrift handler server := thrift.NewServer(ch) server.Register(gen.NewTChanSimpleServiceServer(handler)) return handler.firstCall }, TestCases: map[TracerType][]PropagationTestCase{ Noop: { {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: "", ExpectedSpanCount: 0}, {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: "", ExpectedSpanCount: 0}, }, Mock: { {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: BaggageValue, ExpectedSpanCount: 0}, {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: BaggageValue, ExpectedSpanCount: 6}, }, Jaeger: { {ForwardCount: 2, TracingDisabled: true, ExpectedBaggage: BaggageValue, ExpectedSpanCount: 0}, {ForwardCount: 2, TracingDisabled: false, ExpectedBaggage: BaggageValue, ExpectedSpanCount: 6}, }, }, } suite.Run(t) } ================================================ FILE: thrift/transport.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
// readWriterTransport is a transport that reads from and writes to the
// embedded io.Reader and io.Writer. It keeps small scratch buffers so that
// single-byte and string operations do not allocate on every call.
type readWriterTransport struct {
	io.Writer
	io.Reader

	// readBuf and writeBuf are one-byte scratch buffers for ReadByte and
	// WriteByte.
	readBuf  [1]byte
	writeBuf [1]byte
	// strBuf is a reusable scratch buffer for WriteString.
	strBuf []byte
}

// Open is a no-op; the transport is always considered open.
func (t *readWriterTransport) Open() error { return nil }

// Flush is a no-op; it does not flush the underlying writer.
func (t *readWriterTransport) Flush() error { return nil }

// IsOpen always reports true.
func (t *readWriterTransport) IsOpen() bool { return true }

// Close is a no-op; it does not close the underlying reader or writer.
func (t *readWriterTransport) Close() error { return nil }

// ReadByte reads a single byte from the underlying reader. Reads that return
// zero bytes and no error are retried. If the reader returns a byte together
// with io.EOF, the byte is returned with a nil error.
func (t *readWriterTransport) ReadByte() (byte, error) {
	buf := t.readBuf[:]

	n, err := t.Read(buf)
	for n == 0 && err == nil {
		n, err = t.Read(buf)
	}

	if n > 0 && err == io.EOF {
		err = nil
	}
	return buf[0], err
}

// WriteByte writes a single byte to the underlying writer.
func (t *readWriterTransport) WriteByte(b byte) error {
	t.writeBuf[0] = b
	_, err := t.Write(t.writeBuf[:])
	return err
}

// WriteString writes s to the underlying writer, using the writer's own
// WriteString method when it has one.
func (t *readWriterTransport) WriteString(s string) (int, error) {
	// TODO switch to io.StringWriter once we don't need to support < 1.12
	type stringWriter interface {
		WriteString(string) (int, error)
	}
	if direct, ok := t.Writer.(stringWriter); ok {
		return direct.WriteString(s)
	}

	// This path is taken frequently: thrift.TBinaryProtocol calls WriteString
	// a lot, but fragmentingWriter does not implement it, and adding a string
	// path there is difficult because hash checksumming does not accept
	// strings.
	//
	// Copying into a reusable buffer avoids the allocation that
	// io.WriteString would otherwise make on every call.
	scratch := append(t.strBuf[:0], s...)
	t.strBuf = scratch[:0]
	return t.Writer.Write(scratch)
}

// RemainingBytes returns the max number of bytes (same as Thrift's
// StreamTransport) as we do not know how many bytes we have left.
func (t *readWriterTransport) RemainingBytes() uint64 {
	return ^uint64(0)
}
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package thrift import ( "bytes" "io" "testing" "github.com/uber/tchannel-go/testutils/testreader" "github.com/uber/tchannel-go/testutils/testwriter" "github.com/stretchr/testify/assert" ) func writeByte(writer io.Writer, b byte) error { protocol := getProtocolWriter(writer) return protocol.transport.WriteByte(b) } func TestWriteByteSuccess(t *testing.T) { writer := &bytes.Buffer{} assert.NoError(t, writeByte(writer, 'a'), "WriteByte failed") assert.NoError(t, writeByte(writer, 'b'), "WriteByte failed") assert.NoError(t, writeByte(writer, 'c'), "WriteByte failed") assert.Equal(t, []byte("abc"), writer.Bytes(), "Written bytes mismatch") } func TestWriteByteFailed(t *testing.T) { buf := &bytes.Buffer{} writer := io.MultiWriter(testwriter.Limited(2), buf) assert.NoError(t, writeByte(writer, 'a'), "WriteByte failed") assert.NoError(t, writeByte(writer, 'b'), "WriteByte failed") assert.Error(t, writeByte(writer, 'c'), "WriteByte should fail due to lack of space") assert.Equal(t, []byte("ab"), buf.Bytes(), "Written bytes mismatch") } func TestReadByte0Byte(t *testing.T) { chunkWriter, chunkReader := testreader.ChunkReader() reader := getProtocolReader(chunkReader) chunkWriter <- []byte{} chunkWriter <- []byte{} chunkWriter <- []byte{} chunkWriter <- []byte("abc") close(chunkWriter) b, err := reader.transport.ReadByte() assert.NoError(t, err, "ReadByte should ignore 0 byte reads") assert.EqualValues(t, 'a', b) b, err = reader.transport.ReadByte() assert.NoError(t, err, "ReadByte failed") assert.EqualValues(t, 'b', b) b, err = reader.transport.ReadByte() assert.NoError(t, err, "ReadByte failed") assert.EqualValues(t, 'c', b) b, err = reader.transport.ReadByte() assert.Equal(t, io.EOF, err, "ReadByte should 
// Wrap returns a new Listener around the provided net.Listener.
// The returned Listener guarantees that once Close has returned, it will no
// longer accept any new connections.
// See: https://github.com/uber/tchannel-go/issues/141
func Wrap(l net.Listener) net.Listener {
	var mu sync.Mutex
	return &listener{Listener: l, cond: sync.NewCond(&mu)}
}

// listener wraps a net.Listener and ensures that once Listener.Close returns,
// the underlying socket has been closed.
//
// The default Listener can return from Close before the underlying socket has
// been closed if another goroutine still holds a reference (e.g. is blocked
// in Accept). The problematic sequence is:
//
//	Goroutine 1 is blocked in Accept, waiting for epoll.
//	Goroutine 2 calls Close. It sees the extra reference, so it cannot
//	destroy the socket; it drops a reference, marks the connection as closed
//	and unblocks epoll.
//	Goroutine 2 returns to its caller, which makes a new connection.
//	That connection reaches the socket (it has not been destroyed yet).
//	Goroutine 1 returns from epoll and accepts the new connection.
//
// To avoid accepting connections after Close, Close does not return until
// every in-flight Accept has returned to its caller.
type listener struct {
	net.Listener

	// cond signals Close when no Accept calls hold a reference any more.
	// Its lock also guards refs.
	cond *sync.Cond
	// refs is the number of Accept calls currently in flight.
	refs int
}

// incRef records that an Accept call has started.
func (s *listener) incRef() {
	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	s.refs++
}

// decRef records that an Accept call has finished and wakes any waiting Close
// once the last reference is gone.
func (s *listener) decRef() {
	s.cond.L.Lock()
	s.refs--
	idle := s.refs == 0
	s.cond.L.Unlock()

	if idle {
		s.cond.Broadcast()
	}
}

// Accept waits for and returns the next connection to the listener.
func (s *listener) Accept() (net.Conn, error) {
	s.incRef()
	defer s.decRef()
	return s.Listener.Accept()
}

// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors. Close
// returns only after those Accept calls have returned, unless closing the
// underlying listener fails, in which case that error is returned at once.
func (s *listener) Close() error {
	err := s.Listener.Close()
	if err != nil {
		return err
	}

	s.cond.L.Lock()
	defer s.cond.L.Unlock()
	for s.refs > 0 {
		s.cond.Wait()
	}
	return nil
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tnet import ( "errors" "net" "sync" "testing" "github.com/stretchr/testify/assert" ) func TestListenerAcceptAfterClose(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 16; i++ { wg.Add(1) go func() { defer wg.Done() for i := 0; i < 10; i++ { runTest(t) } }() } wg.Wait() } func runTest(t *testing.T) { const connectionsBeforeClose = 1 ln, err := net.Listen("tcp", "127.0.0.1:0") if !assert.NoError(t, err, "Listen failed") { return } ln = Wrap(ln) addr := ln.Addr().String() waitForListener := make(chan error) go func() { defer close(waitForListener) var connCount int for { conn, err := ln.Accept() if err != nil { return } connCount++ if connCount > connectionsBeforeClose { waitForListener <- errors.New("got unexpected conn") return } conn.Close() } }() for i := 0; i < connectionsBeforeClose; i++ { err := connect(addr) if !assert.NoError(t, err, "connect before listener is closed should succeed") { return } } ln.Close() connect(addr) err = <-waitForListener assert.NoError(t, err, "got connection after listener was closed") } func connect(addr string) error { conn, err := net.Dial("tcp", addr) if err == nil { conn.Close() } return err } ================================================ FILE: tos/tos.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// ToS represents one of the named traffic-class constants declared below
// (CS3, AF11, EF, Lowdelay, ...). The names come from:
//
//	Assured Forwarding (x=class, y=drop precedence) (RFC 2597)
//	Class Selector (RFC 2474)
//	Expedited Forwarding (RFC 3246)
//	IP Precedence (Linux Socket Compat, RFC 791)
type ToS uint8

// Named ToS values. AFxy is Assured Forwarding with class x and drop
// precedence y (RFC 2597); CSn is Class Selector n (RFC 2474).
const (
	// CS3 Class Selector 3
	CS3 ToS = 0x18
	// CS4 Class Selector 4
	CS4 ToS = 0x20
	// CS5 Class Selector 5
	CS5 ToS = 0x28
	// CS6 Class Selector 6
	CS6 ToS = 0x30
	// CS7 Class Selector 7
	CS7 ToS = 0x38
	// AF11 Assured Forward 11
	AF11 ToS = 0x0a
	// AF12 Assured Forward 12
	AF12 ToS = 0x0c
	// AF13 Assured Forward 13
	AF13 ToS = 0x0e
	// AF21 Assured Forward 21
	AF21 ToS = 0x12
	// AF22 Assured Forward 22
	AF22 ToS = 0x14
	// AF23 Assured Forward 23
	AF23 ToS = 0x16
	// AF31 Assured Forward 31
	AF31 ToS = 0x1a
	// AF32 Assured Forward 32
	AF32 ToS = 0x1c
	// AF33 Assured Forward 33
	AF33 ToS = 0x1e
	// AF41 Assured Forward 41
	AF41 ToS = 0x22
	// AF42 Assured Forward 42
	AF42 ToS = 0x24
	// AF43 Assured Forward 43
	AF43 ToS = 0x26
	// EF Expedited Forwarding (RFC 3246)
	EF ToS = 0x2e
	// Lowdelay 0x10
	Lowdelay ToS = 0x10
	// Throughput 0x08
	Throughput ToS = 0x08
	// Reliability 0x04
	Reliability ToS = 0x04
	// Lowcost 0x02
	Lowcost ToS = 0x02
)
"AF12", AF13: "AF13", AF21: "AF21", AF22: "AF22", AF23: "AF23", AF31: "AF31", AF32: "AF32", AF33: "AF33", AF41: "AF41", AF42: "AF42", AF43: "AF43", EF: "EF", Lowdelay: "Lowdelay", Throughput: "Throughput", Reliability: "Reliability", Lowcost: "Lowcost", } ) func init() { _tosNameToValue = make(map[string]ToS, len(_tosValueToName)) for tos, tosString := range _tosValueToName { _tosNameToValue[tosString] = tos } } // MarshalText implements TextMarshaler from encoding func (r ToS) MarshalText() ([]byte, error) { return []byte(_tosValueToName[r]), nil } // UnmarshalText implements TextUnMarshaler from encoding func (r *ToS) UnmarshalText(data []byte) error { if v, ok := _tosNameToValue[string(data)]; ok { *r = v return nil } return fmt.Errorf("invalid ToS %q", string(data)) } ================================================ FILE: tos/tos_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package tos import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMarshal(t *testing.T) { for tos, wantMarshalled := range _tosValueToName { marshalled, err := tos.MarshalText() require.NoError(t, err, "Failed to marshal %v", tos) assert.Equal(t, wantMarshalled, string(marshalled)) var got ToS err = got.UnmarshalText(marshalled) require.NoError(t, err, "Failed to unmarshal %v", string(marshalled)) assert.Equal(t, tos, got) } } func TestUnmarshalUnknown(t *testing.T) { var tos ToS err := tos.UnmarshalText([]byte("unknown")) require.Error(t, err, "Should fail to unmarshal unknown value") } ================================================ FILE: trace/doc.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. /* Package trace used to contain TChannel's distributed tracing functionality. It has since been replaced by integration with OpenTracing API. See http://opentracing.io This package is kept to alleviate problems with `glide update`, which tries to look for it during the dependencies upgrades. */ package trace ================================================ FILE: tracing.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package tchannel import ( "fmt" "time" "github.com/uber/tchannel-go/trand" "github.com/uber/tchannel-go/typed" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "golang.org/x/net/context" ) // zipkinSpanFormat defines a name for OpenTracing carrier format that tracer may support. // It is used to extract zipkin-style trace/span IDs from the OpenTracing Span, which are // otherwise not exposed explicitly. // NB: the string value is what's actually shared between implementations const zipkinSpanFormat = "zipkin-span-format" const componentName = "tchannel-go" // Span is an internal representation of Zipkin-compatible OpenTracing Span. // It is used as OpenTracing inject/extract Carrier with ZipkinSpanFormat. type Span struct { traceID uint64 parentID uint64 spanID uint64 flags byte } var ( // traceRng is a thread-safe random number generator for generating trace IDs. traceRng = trand.NewSeeded() // emptySpan is returned from CurrentSpan(ctx) when there is no OpenTracing // Span in ctx, to avoid returning nil. emptySpan Span ) func (s Span) String() string { return fmt.Sprintf("TraceID=%x,ParentID=%x,SpanID=%x", s.traceID, s.parentID, s.spanID) } func (s *Span) read(r *typed.ReadBuffer) error { s.spanID = r.ReadUint64() s.parentID = r.ReadUint64() s.traceID = r.ReadUint64() s.flags = r.ReadSingleByte() return r.Err() } func (s *Span) write(w *typed.WriteBuffer) error { w.WriteUint64(s.spanID) w.WriteUint64(s.parentID) w.WriteUint64(s.traceID) w.WriteSingleByte(s.flags) return w.Err() } func (s *Span) initRandom() { s.traceID = uint64(traceRng.Int63()) s.spanID = s.traceID s.parentID = 0 } // TraceID returns the trace id for the entire call graph of requests. 
Established // at the outermost edge service and propagated through all calls func (s Span) TraceID() uint64 { return s.traceID } // ParentID returns the id of the parent span in this call graph func (s Span) ParentID() uint64 { return s.parentID } // SpanID returns the id of this specific RPC func (s Span) SpanID() uint64 { return s.spanID } // Flags returns flags bitmap. Interpretation of the bits is up to the tracing system. func (s Span) Flags() byte { return s.flags } type injectableSpan Span // SetTraceID sets traceID func (s *injectableSpan) SetTraceID(traceID uint64) { s.traceID = traceID } // SetSpanID sets spanID func (s *injectableSpan) SetSpanID(spanID uint64) { s.spanID = spanID } // SetParentID sets parentID func (s *injectableSpan) SetParentID(parentID uint64) { s.parentID = parentID } // SetFlags sets flags func (s *injectableSpan) SetFlags(flags byte) { s.flags = flags } // initFromOpenTracing initializes injectableSpan fields from an OpenTracing Span, // assuming the tracing implementation supports Zipkin-style span IDs. func (s *injectableSpan) initFromOpenTracing(span opentracing.Span) error { return span.Tracer().Inject(span.Context(), zipkinSpanFormat, s) } // CurrentSpan extracts OpenTracing Span from the Context, and if found tries to // extract zipkin-style trace/span IDs from it using ZipkinSpanFormat carrier. // If there is no OpenTracing Span in the Context, an empty span is returned. func CurrentSpan(ctx context.Context) *Span { if sp := opentracing.SpanFromContext(ctx); sp != nil { var injectable injectableSpan if err := injectable.initFromOpenTracing(sp); err == nil { span := Span(injectable) return &span } // return empty span on error, instead of possibly a partially filled one } return &emptySpan } // startOutboundSpan creates a new tracing span to represent the outbound RPC call. // If the context already contains a span, it will be used as a parent, otherwise // a new root span is created. 
// // If the tracer supports Zipkin-style trace IDs, then call.callReq.Tracing is // initialized with those IDs. Otherwise it is assigned random values. func (c *Connection) startOutboundSpan(ctx context.Context, serviceName, methodName string, call *OutboundCall, startTime time.Time) opentracing.Span { var parent opentracing.SpanContext // ok to be nil if s := opentracing.SpanFromContext(ctx); s != nil { parent = s.Context() } span := c.Tracer().StartSpan( methodName, opentracing.ChildOf(parent), opentracing.StartTime(startTime), ) if isTracingDisabled(ctx) { ext.SamplingPriority.Set(span, 0) } ext.SpanKindRPCClient.Set(span) ext.PeerService.Set(span, serviceName) ext.Component.Set(span, componentName) c.setPeerHostPort(span) span.SetTag("as", call.callReq.Headers[ArgScheme]) var injectable injectableSpan if err := injectable.initFromOpenTracing(span); err == nil { call.callReq.Tracing = Span(injectable) } else { call.callReq.Tracing.initRandom() } return span } // InjectOutboundSpan retrieves OpenTracing Span from `response`, where it is stored // when the outbound call is initiated. The tracing API is used to serialize the span // into the application `headers`, which will propagate tracing context to the server. // Returns modified headers containing serialized tracing context. // // Sometimes caller pass a shared instance of the `headers` map, so instead of modifying // it we clone it into the new map (assuming that Tracer actually injects some tracing keys). func InjectOutboundSpan(response *OutboundCallResponse, headers map[string]string) map[string]string { span := response.span if span == nil { return headers } newHeaders := make(map[string]string) carrier := tracingHeadersCarrier(newHeaders) if err := span.Tracer().Inject(span.Context(), opentracing.TextMap, carrier); err != nil { // Something had to go seriously wrong for Inject to fail, usually a setup problem. // A good Tracer implementation may also emit a metric. 
response.log.WithFields(ErrField(err)).Error("Failed to inject tracing span.") } if len(newHeaders) == 0 { return headers // Tracer did not add any tracing headers, so return the original map } for k, v := range headers { // Some applications propagate all inbound application headers to outbound calls (issue #682). // If those headers include tracing headers we want to make sure to keep the new tracing headers. if _, ok := newHeaders[k]; !ok { newHeaders[k] = v } } return newHeaders } // extractInboundSpan attempts to create a new OpenTracing Span for inbound request // using only trace IDs stored in the frame's tracing field. It only works if the // tracer understand Zipkin-style trace IDs. If such attempt fails, another attempt // will be made from the higher level function ExtractInboundSpan() once the // application headers are read from the wire. func (c *Connection) extractInboundSpan(callReq *callReq) opentracing.Span { spanCtx, err := c.Tracer().Extract(zipkinSpanFormat, &callReq.Tracing) if err != nil { if err != opentracing.ErrUnsupportedFormat && err != opentracing.ErrSpanContextNotFound { c.log.WithFields(ErrField(err)).Error("Failed to extract Zipkin-style span.") } return nil } if spanCtx == nil { return nil } operationName := "" // not known at this point, will be set later span := c.Tracer().StartSpan(operationName, ext.RPCServerOption(spanCtx)) span.SetTag("as", callReq.Headers[ArgScheme]) ext.PeerService.Set(span, callReq.Headers[CallerName]) ext.Component.Set(span, componentName) c.setPeerHostPort(span) return span } // ExtractInboundSpan is a higher level version of extractInboundSpan(). // If the lower-level attempt to create a span from incoming request was // successful (e.g. when then Tracer supports Zipkin-style trace IDs), // then the application headers are only used to read the Baggage and add // it to the existing span. 
Otherwise, the standard OpenTracing API supported // by all tracers is used to deserialize the tracing context from the // application headers and start a new server-side span. // Once the span is started, it is wrapped in a new Context, which is returned. func ExtractInboundSpan(ctx context.Context, call *InboundCall, headers map[string]string, tracer opentracing.Tracer) context.Context { var span = call.Response().span if span != nil { if headers != nil { // extract SpanContext from headers, but do not start another span with it, // just get the baggage and copy to the already created span carrier := tracingHeadersCarrier(headers) if sc, err := tracer.Extract(opentracing.TextMap, carrier); err == nil { sc.ForeachBaggageItem(func(k, v string) bool { span.SetBaggageItem(k, v) return true }) } carrier.RemoveTracingKeys() } } else { var parent opentracing.SpanContext if headers != nil { carrier := tracingHeadersCarrier(headers) if p, err := tracer.Extract(opentracing.TextMap, carrier); err == nil { parent = p } carrier.RemoveTracingKeys() } span = tracer.StartSpan(call.MethodString(), ext.RPCServerOption(parent)) ext.PeerService.Set(span, call.CallerName()) ext.Component.Set(span, componentName) span.SetTag("as", string(call.Format())) call.conn.setPeerHostPort(span) call.Response().span = span } return opentracing.ContextWithSpan(ctx, span) } func (c *Connection) setPeerHostPort(span opentracing.Span) { if c.remotePeerAddress.ipv4 != 0 { ext.PeerHostIPv4.Set(span, c.remotePeerAddress.ipv4) } if c.remotePeerAddress.ipv6 != "" { ext.PeerHostIPv6.Set(span, c.remotePeerAddress.ipv6) } if c.remotePeerAddress.hostname != "" { ext.PeerHostname.Set(span, c.remotePeerAddress.hostname) } if c.remotePeerAddress.port != 0 { ext.PeerPort.Set(span, c.remotePeerAddress.port) } } type tracerProvider interface { Tracer() opentracing.Tracer } // TracerFromRegistrar returns an OpenTracing Tracer embedded in the Registrar, // assuming that Registrar has a Tracer() method. 
Otherwise it returns default Global Tracer. func TracerFromRegistrar(registrar Registrar) opentracing.Tracer { if tracerProvider, ok := registrar.(tracerProvider); ok { return tracerProvider.Tracer() } return opentracing.GlobalTracer() } ================================================ FILE: tracing_internal_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
// TestTracingSpanEncoding checks that Span.write produces the documented wire
// layout and that Span.read decodes it back into an equal Span.
func TestTracingSpanEncoding(t *testing.T) {
	s1 := Span{
		traceID:  1,
		parentID: 2,
		spanID:   3,
		flags:    4,
	}
	// Encoding is: spanid:8 parentid:8 traceid:8 traceflags:1
	// http://tchannel.readthedocs.io/en/latest/protocol/#tracing
	encoded := []byte{
		0, 0, 0, 0, 0, 0, 0, 3, /* spanID */
		0, 0, 0, 0, 0, 0, 0, 2, /* parentID */
		0, 0, 0, 0, 0, 0, 0, 1, /* traceID */
		4, /* flags */
	}
	buf := make([]byte, len(encoded))
	writer := typed.NewWriteBuffer(buf)
	require.NoError(t, s1.write(writer), "Failed to encode span")

	assert.Equal(t, encoded, buf, "Encoded span mismatch")

	var s2 Span
	reader := typed.NewReadBuffer(buf)
	require.NoError(t, s2.read(reader), "Failed to decode span")

	assert.Equal(t, s1, s2, "Roundtrip of span failed")
}

// TestTracingInjectorExtractor registers the test Zipkin injector/extractor on
// a mock tracer and checks that a span can be injected into an injectableSpan
// and extracted again from the resulting Span.
func TestTracingInjectorExtractor(t *testing.T) {
	tracer := mocktracer.New()
	tracer.RegisterInjector(zipkinSpanFormat, new(zipkinInjector))
	tracer.RegisterExtractor(zipkinSpanFormat, new(zipkinExtractor))

	sp := tracer.StartSpan("x")
	var injectable injectableSpan
	err := tracer.Inject(sp.Context(), zipkinSpanFormat, &injectable)
	require.NoError(t, err)

	tsp := Span(injectable)
	assert.NotEqual(t, uint64(0), tsp.TraceID())
	assert.NotEqual(t, uint64(0), tsp.SpanID())

	sp2, err := tracer.Extract(zipkinSpanFormat, &tsp)
	require.NoError(t, err)
	require.NotNil(t, sp2)
}

// TestSpanString pins the hexadecimal format produced by Span.String.
func TestSpanString(t *testing.T) {
	span := Span{traceID: 15}
	assert.Equal(t, "TraceID=f,ParentID=0,SpanID=0", span.String())
}

// TestSetPeerHostPort parses a range of host:port strings into a remote peer
// address and checks which peer.* tags setPeerHostPort puts on the span.
func TestSetPeerHostPort(t *testing.T) {
	tracer := mocktracer.New()

	ipv6 := []byte{1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 16}
	assert.Equal(t, net.IPv6len, len(ipv6))
	ipv6hostPort := fmt.Sprintf("[%v]:789", net.IP(ipv6))

	tests := []struct {
		hostPort    string
		wantHostTag string
		wantHost    interface{}
		wantPort    uint16 // 0 means the port tag is expected to be absent
	}{
		{"adhoc123:bad-port", "peer.hostname", "adhoc123", 0},
		{"adhoc123", "peer.hostname", "adhoc123", 0},
		{"ip123.uswest.aws.com:765", "peer.hostname", "ip123.uswest.aws.com", 765},
		{"localhost:123", "peer.ipv4", uint32(127<<24 | 1), 123},
		{"10.20.30.40:321", "peer.ipv4", uint32(10<<24 | 20<<16 | 30<<8 | 40), 321},
		{ipv6hostPort, "peer.ipv6", "102:300::f10", 789},
	}

	for i, test := range tests {
		span := tracer.StartSpan("x")
		peerInfo, peerAddress, err := parseRemotePeer(initParams{
			InitParamHostPort:    test.hostPort,
			InitParamProcessName: "test",
		}, &net.IPAddr{IP: net.ParseIP("1.1.1.1")})
		require.NoError(t, err, "Failed to parse remote peer info")

		c := &Connection{
			channelConnectionCommon: channelConnectionCommon{
				log: NullLogger,
			},
			remotePeerInfo:    peerInfo,
			remotePeerAddress: peerAddress,
		}
		c.setPeerHostPort(span)
		span.Finish()

		// One span is finished per iteration, so the i-th finished span
		// belongs to the i-th test case.
		rawSpan := tracer.FinishedSpans()[i]
		assert.Equal(t, test.wantHost, rawSpan.Tag(test.wantHostTag), "test %+v", test)
		if test.wantPort != 0 {
			assert.Equal(t, test.wantPort, rawSpan.Tag(string(ext.PeerPort)), "test %+v", test)
		} else {
			assert.Nil(t, rawSpan.Tag(string(ext.PeerPort)), "test %+v", test)
		}
	}
}

// TestExtractInboundSpanWithZipkinTracer exercises both inbound extraction
// steps: the frame-level extractInboundSpan (which needs a Zipkin-format
// extractor) and the header-level ExtractInboundSpan (which should reuse the
// span and only copy baggage onto it).
func TestExtractInboundSpanWithZipkinTracer(t *testing.T) {
	tracer := mocktracer.New()

	callReq := new(callReq)
	callReq.Tracing = Span{traceID: 1, spanID: 2, flags: 1}
	callReq.Headers = transportHeaders{
		ArgScheme:  string(JSON),
		CallerName: "caller",
	}

	peerInfo, peerAddress, err := parseRemotePeer(initParams{
		InitParamHostPort:    "host:123",
		InitParamProcessName: "test",
	}, &net.IPAddr{IP: net.ParseIP("1.1.1.1")})
	require.NoError(t, err, "Failed to parse remote peer info")

	c := Connection{
		channelConnectionCommon: channelConnectionCommon{
			log:    NullLogger,
			tracer: tracer,
		},
		remotePeerInfo:    peerInfo,
		remotePeerAddress: peerAddress,
	}

	// fail to extract with zipkin format, as MockTracer does not support it
	assert.Nil(t, c.extractInboundSpan(callReq), "zipkin format not available")

	// add zipkin format extractor and try again
	tracer.RegisterExtractor(zipkinSpanFormat, new(zipkinExtractor))
	span := c.extractInboundSpan(callReq)
	require.NotNil(t, span, "zipkin format available")

	// validate the extracted span was correctly populated
	s1, ok := span.(*mocktracer.MockSpan)
	require.True(t, ok)
	assert.Equal(t, 1, s1.SpanContext.TraceID)
	assert.Equal(t, 2, s1.ParentID)
	assert.True(t, s1.SpanContext.Sampled)
	assert.Equal(t, "", s1.OperationName, "operation name unknown initially")
	assert.Equal(t, string(JSON), s1.Tag("as"))
	assert.Equal(t, "caller", s1.Tag(string(ext.PeerService)))
	assert.Equal(t, "host", s1.Tag(string(ext.PeerHostname)))
	assert.Equal(t, uint16(123), s1.Tag(string(ext.PeerPort)))

	// start a temporary span so that we can populate headers with baggage
	tempSpan := tracer.StartSpan("test")
	tempSpan.SetBaggageItem("x", "y")
	headers := make(map[string]string)
	carrier := tracingHeadersCarrier(headers)
	err = tracer.Inject(tempSpan.Context(), opentracing.TextMap, carrier)
	assert.NoError(t, err)

	// run the public ExtractInboundSpan method with application headers
	inCall := &InboundCall{
		response: &InboundCallResponse{
			span: span,
		},
	}
	ctx := context.Background()
	ctx2 := ExtractInboundSpan(ctx, inCall, headers, tracer)
	span = opentracing.SpanFromContext(ctx2)
	s2, ok := span.(*mocktracer.MockSpan)
	require.True(t, ok)
	assert.Equal(t, s1, s2, "should be the same span started previously")
	assert.Equal(t, "y", s2.BaggageItem("x"), "baggage should've been added")
}

// zipkinInjector is a mocktracer injector for zipkinSpanFormat. It copies the
// trace ID, span ID and sampling flag into an *injectableSpan carrier; the
// parent ID is not set.
type zipkinInjector struct{}

// Inject fills the carrier from sc, or returns ErrInvalidCarrier when the
// carrier is not an *injectableSpan.
func (z *zipkinInjector) Inject(sc mocktracer.MockSpanContext, carrier interface{}) error {
	span, ok := carrier.(*injectableSpan)
	if !ok {
		return opentracing.ErrInvalidCarrier
	}
	span.SetTraceID(uint64(sc.TraceID))
	span.SetSpanID(uint64(sc.SpanID))
	if sc.Sampled {
		span.SetFlags(1)
	} else {
		span.SetFlags(0)
	}
	return nil
}

// zipkinExtractor is a mocktracer extractor for zipkinSpanFormat that reads
// the IDs and the sampling bit from a *Span carrier.
type zipkinExtractor struct{}

// Extract builds a MockSpanContext from the carrier, or returns
// ErrInvalidCarrier when the carrier is not a *Span.
func (z *zipkinExtractor) Extract(carrier interface{}) (mocktracer.MockSpanContext, error) {
	span, ok := carrier.(*Span)
	if !ok {
		return mocktracer.MockSpanContext{}, opentracing.ErrInvalidCarrier
	}
	return mocktracer.MockSpanContext{
		TraceID: int(span.traceID),
		SpanID:  int(span.spanID),
		Sampled: span.flags&1 == 1, // lowest flag bit carries the sampling decision
	}, nil
}
const tracingKeyMappingSize = 100 type tracingKeysMapping struct { sync.RWMutex mapping map[string]string mapper func(key string) string } var tracingKeyEncoding = &tracingKeysMapping{ mapping: make(map[string]string), mapper: func(key string) string { return tracingKeyPrefix + key }, } var tracingKeyDecoding = &tracingKeysMapping{ mapping: make(map[string]string), mapper: func(key string) string { return key[len(tracingKeyPrefix):] }, } func (m *tracingKeysMapping) mapAndCache(key string) string { m.RLock() v, ok := m.mapping[key] m.RUnlock() if ok { return v } m.Lock() defer m.Unlock() if v, ok := m.mapping[key]; ok { return v } mappedKey := m.mapper(key) if len(m.mapping) < tracingKeyMappingSize { m.mapping[key] = mappedKey } return mappedKey } type tracingHeadersCarrier map[string]string // Set implements Set() of opentracing.TextMapWriter func (c tracingHeadersCarrier) Set(key, val string) { prefixedKey := tracingKeyEncoding.mapAndCache(key) c[prefixedKey] = val } // ForeachKey conforms to the TextMapReader interface. func (c tracingHeadersCarrier) ForeachKey(handler func(key, val string) error) error { for k, v := range c { if !strings.HasPrefix(k, tracingKeyPrefix) { continue } noPrefixKey := tracingKeyDecoding.mapAndCache(k) if err := handler(noPrefixKey, v); err != nil { return err } } return nil } func (c tracingHeadersCarrier) RemoveTracingKeys() { for key := range c { if strings.HasPrefix(key, tracingKeyPrefix) { delete(c, key) } } } ================================================ FILE: tracing_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// JSONHandler tests tracing over JSON encoding
type JSONHandler struct {
	testtracing.TraceHandler
	t *testing.T
	// sideEffect, when set, is invoked with the handler's context on every call.
	sideEffect func(ctx json.Context)
}

// callJSON is the JSON endpoint: it records the span observed in ctx into the
// response and then runs the optional sideEffect hook.
func (h *JSONHandler) callJSON(ctx json.Context, req *testtracing.TracingRequest) (*testtracing.TracingResponse, error) {
	resp := new(testtracing.TracingResponse)
	resp.ObserveSpan(ctx)
	if h.sideEffect != nil {
		h.sideEffect(ctx)
	}
	return resp, nil
}

// onError fails the test whenever the JSON server reports an error.
func (h *JSONHandler) onError(ctx context.Context, err error) { h.t.Errorf("onError %v", err) }

// TestTracingSpanAttributes makes one JSON call with a mock tracer and checks
// the resulting client and server spans: IDs, sampling, operation name and the
// peer/component tags. It also checks that stale tracing headers supplied by
// the application do not break trace propagation.
func TestTracingSpanAttributes(t *testing.T) {
	tracer := mocktracer.New()

	opts := &testutils.ChannelOpts{
		ChannelOptions: ChannelOptions{Tracer: tracer},
		DisableRelay:   true,
	}
	WithVerifiedServer(t, opts, func(ch *Channel, hostPort string) {
		const (
			customAppHeaderKey           = "futurama"
			customAppHeaderExpectedValue = "simpsons"
		)
		// Written by the handler goroutine, read by the test goroutine.
		var customAppHeaderValue atomic.String
		// Register JSON handler
		jsonHandler := &JSONHandler{
			TraceHandler: testtracing.TraceHandler{Ch: ch},
			t:            t,
			sideEffect: func(ctx json.Context) {
				customAppHeaderValue.Store(ctx.Headers()[customAppHeaderKey])
			},
		}
		json.Register(ch, json.Handlers{"call": jsonHandler.callJSON}, jsonHandler.onError)

		span := ch.Tracer().StartSpan("client")
		span.SetBaggageItem(testtracing.BaggageKey, testtracing.BaggageValue)
		ctx := opentracing.ContextWithSpan(context.Background(), span)
		root := new(testtracing.TracingResponse).ObserveSpan(ctx)

		// Pretend that the client propagated tracing headers from upstream call, and test that the outbound call
		// will override them (https://github.com/uber/tchannel-go/issues/682).
		tracingHeaders := make(map[string]string)
		assert.NoError(t, tracer.Inject(span.Context(), opentracing.TextMap, opentracing.TextMapCarrier(tracingHeaders)))
		requestHeaders := map[string]string{
			customAppHeaderKey: customAppHeaderExpectedValue,
		}
		for k := range tracingHeaders {
			requestHeaders["$tracing$"+k] = "garbage"
		}

		ctx, cancel := NewContextBuilder(2 * time.Second).SetParentContext(ctx).Build()
		defer cancel()

		peer := ch.Peers().GetOrAdd(ch.PeerInfo().HostPort)
		var response testtracing.TracingResponse
		require.NoError(t, json.CallPeer(json.WithHeaders(ctx, requestHeaders), peer, ch.PeerInfo().ServiceName,
			"call", &testtracing.TracingRequest{}, &response))
		assert.Equal(t, customAppHeaderExpectedValue, customAppHeaderValue.Load(), "custom header was propagated")

		// Spans are finished in inbound.doneSending() or outbound.doneReading(),
		// which are called on different go-routines and may execute *after* the
		// response has been received by the client. Give them a chance to finish.
		for i := 0; i < 1000; i++ {
			if spanCount := len(testtracing.MockTracerSampledSpans(tracer)); spanCount == 2 {
				break
			}
			time.Sleep(testutils.Timeout(time.Millisecond))
		}
		spans := testtracing.MockTracerSampledSpans(tracer)
		spanCount := len(spans)
		ch.Logger().Debugf("end span count: %d", spanCount)

		// finish span after taking count of recorded spans
		span.Finish()

		require.Equal(t, 2, spanCount, "Wrong span count")
		assert.Equal(t, root.TraceID, response.TraceID, "Trace ID must match root span")
		assert.Equal(t, testtracing.BaggageValue, response.Luggage, "Baggage must match")

		// Tell the two recorded spans apart by their span.kind tag: the
		// client-side span is the parent, the server-side span the child.
		var parent, child *mocktracer.MockSpan
		for _, s := range spans {
			if s.Tag("span.kind") == ext.SpanKindRPCClientEnum {
				parent = s
				ch.Logger().Debugf("Found parent span: %+v", s)
			} else if s.Tag("span.kind") == ext.SpanKindRPCServerEnum {
				child = s
				ch.Logger().Debugf("Found child span: %+v", s)
			}
		}

		require.NotNil(t, parent)
		require.NotNil(t, child)

		// Small accessors over the mock span context.
		traceID := func(s opentracing.Span) int {
			return s.Context().(mocktracer.MockSpanContext).TraceID
		}
		spanID := func(s *mocktracer.MockSpan) int {
			return s.Context().(mocktracer.MockSpanContext).SpanID
		}
		sampled := func(s *mocktracer.MockSpan) bool {
			return s.Context().(mocktracer.MockSpanContext).Sampled
		}

		require.Equal(t, traceID(span), traceID(parent), "parent must be found")
		require.Equal(t, traceID(span), traceID(child), "child must be found")
		assert.Equal(t, traceID(parent), traceID(child))
		assert.Equal(t, spanID(parent), child.ParentID)
		assert.True(t, sampled(parent), "should be sampled")
		assert.True(t, sampled(child), "should be sampled")
		assert.Equal(t, "call", parent.OperationName)
		assert.Equal(t, "call", child.OperationName)
		assert.Equal(t, "testService", parent.Tag("peer.service"))
		assert.Equal(t, "testService", child.Tag("peer.service"))
		assert.Equal(t, "json", parent.Tag("as"))
		assert.Equal(t, "json", child.Tag("as"))
		assert.NotNil(t, parent.Tag("peer.ipv4"))
		assert.NotNil(t, child.Tag("peer.ipv4"))
		assert.NotNil(t, parent.Tag("peer.port"))
		assert.NotNil(t, child.Tag("peer.port"))
		assert.Equal(t, "tchannel-go", parent.Tag("component"))
		assert.Equal(t, "tchannel-go", child.Tag("component"))
	})
}

// Per https://github.com/uber/tchannel-go/issues/505, concurrent client calls
// made with the same shared map used as headers were causing panic due to
// concurrent writes to the map when injecting tracing headers.
func TestReusableHeaders(t *testing.T) {
	opts := &testutils.ChannelOpts{
		ChannelOptions: ChannelOptions{Tracer: mocktracer.New()},
	}
	WithVerifiedServer(t, opts, func(ch *Channel, hostPort string) {
		jsonHandler := &JSONHandler{TraceHandler: testtracing.TraceHandler{Ch: ch}, t: t}
		json.Register(ch, json.Handlers{"call": jsonHandler.callJSON}, jsonHandler.onError)

		span := ch.Tracer().StartSpan("client")
		traceID := span.(*mocktracer.MockSpan).SpanContext.TraceID // for validation
		ctx := opentracing.ContextWithSpan(context.Background(), span)

		// One headers map shared by every concurrent call below.
		sharedHeaders := map[string]string{"life": "42"}
		ctx, cancel := NewContextBuilder(2 * time.Second).
			SetHeaders(sharedHeaders).
			SetParentContext(ctx).
			Build()
		defer cancel()

		peer := ch.Peers().GetOrAdd(ch.PeerInfo().HostPort)

		var wg sync.WaitGroup
		for i := 0; i < 42; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				var response testtracing.TracingResponse
				err := json.CallPeer(json.Wrap(ctx), peer, ch.ServiceName(),
					"call", &testtracing.TracingRequest{}, &response)
				assert.NoError(t, err, "json.Call failed")
				assert.EqualValues(t, traceID, response.TraceID, "traceID must match")
			}()
		}
		wg.Wait()
		// The shared map must not have picked up any tracing keys.
		assert.Equal(t, map[string]string{"life": "42"}, sharedHeaders, "headers unchanged")
	})
}
// lockedSource allows a random number generator to be used by multiple goroutines
// concurrently. The code is very similar to math/rand.lockedSource, which is
// unfortunately not exposed.
type lockedSource struct {
	sync.Mutex
	src rand.Source
}

// New returns a rand.Rand that is threadsafe.
// Every draw and reseed goes through a mutex-guarded source seeded with seed.
func New(seed int64) *rand.Rand {
	guarded := &lockedSource{src: rand.NewSource(seed)}
	return rand.New(guarded)
}

// NewSeeded returns a rand.Rand that's threadsafe and seeded with the current
// time.
func NewSeeded() *rand.Rand {
	seed := time.Now().UnixNano()
	return New(seed)
}

// Int63 returns the next value of the wrapped source while holding the lock.
func (r *lockedSource) Int63() int64 {
	r.Lock()
	defer r.Unlock()
	return r.src.Int63()
}

// Seed reseeds the wrapped source while holding the lock.
func (r *lockedSource) Seed(seed int64) {
	r.Lock()
	defer r.Unlock()
	r.src.Seed(seed)
}
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package typed import ( "encoding/binary" "errors" "io" ) var ( // ErrEOF is returned when trying to read past end of buffer ErrEOF = errors.New("buffer is too small") // ErrBufferFull is returned when trying to write past end of buffer ErrBufferFull = errors.New("no more room in buffer") // errStringTooLong is returned when writing a string with length larger // than the allows length limit. Intentionally not exported, in case we // want to add more context in future. errStringTooLong = errors.New("string is too long") ) // A ReadBuffer is a wrapper around an underlying []byte with methods to read from // that buffer in big-endian format. 
type ReadBuffer struct { initialLength int remaining []byte err error } // NewReadBuffer returns a ReadBuffer wrapping a byte slice func NewReadBuffer(buffer []byte) *ReadBuffer { return &ReadBuffer{initialLength: len(buffer), remaining: buffer} } // ReadSingleByte reads the next byte from the buffer func (r *ReadBuffer) ReadSingleByte() byte { b, _ := r.ReadByte() return b } // ReadByte returns the next byte from the buffer. // // This method implements the ByteReader interface. func (r *ReadBuffer) ReadByte() (byte, error) { if r.err != nil { return 0, r.err } if len(r.remaining) < 1 { r.err = ErrEOF return 0, r.err } b := r.remaining[0] r.remaining = r.remaining[1:] return b, nil } // ReadBytes returns the next n bytes from the buffer func (r *ReadBuffer) ReadBytes(n int) []byte { if r.err != nil { return nil } if len(r.remaining) < n { r.err = ErrEOF return nil } b := r.remaining[0:n] r.remaining = r.remaining[n:] return b } // SkipBytes skips the next n bytes from the buffer func (r *ReadBuffer) SkipBytes(n int) { if r.err != nil { return } if len(r.remaining) < n { r.err = ErrEOF return } r.remaining = r.remaining[n:] } // ReadString returns a string of size n from the buffer func (r *ReadBuffer) ReadString(n int) string { if b := r.ReadBytes(n); b != nil { // TODO(mmihic): This creates a copy, which sucks return string(b) } return "" } // ReadUint16 returns the next value in the buffer as a uint16 func (r *ReadBuffer) ReadUint16() uint16 { if b := r.ReadBytes(2); b != nil { return binary.BigEndian.Uint16(b) } return 0 } // ReadUint32 returns the next value in the buffer as a uint32 func (r *ReadBuffer) ReadUint32() uint32 { if b := r.ReadBytes(4); b != nil { return binary.BigEndian.Uint32(b) } return 0 } // ReadUint64 returns the next value in the buffer as a uint64 func (r *ReadBuffer) ReadUint64() uint64 { if b := r.ReadBytes(8); b != nil { return binary.BigEndian.Uint64(b) } return 0 } // ReadUvarint reads an unsigned varint from the buffer. 
func (r *ReadBuffer) ReadUvarint() uint64 { v, _ := binary.ReadUvarint(r) return v } // ReadLen8String reads an 8-bit length preceded string value func (r *ReadBuffer) ReadLen8String() string { n := r.ReadSingleByte() return r.ReadString(int(n)) } // ReadLen16String reads a 16-bit length preceded string value func (r *ReadBuffer) ReadLen16String() string { n := r.ReadUint16() return r.ReadString(int(n)) } // Remaining returns the unconsumed bytes. func (r *ReadBuffer) Remaining() []byte { return r.remaining } // BytesRemaining returns the length of Remaining. func (r *ReadBuffer) BytesRemaining() int { return len(r.Remaining()) } // BytesRead returns the number of bytes consumed func (r *ReadBuffer) BytesRead() int { return r.initialLength - len(r.remaining) } // Wrap initializes the buffer to read from the given byte slice func (r *ReadBuffer) Wrap(b []byte) { r.initialLength = len(b) r.remaining = b r.err = nil } // Err returns the error in the ReadBuffer func (r *ReadBuffer) Err() error { return r.err } // A WriteBuffer is a wrapper around an underlying []byte with methods to write to // that buffer in big-endian format. The buffer is of fixed size, and does not grow. 
type WriteBuffer struct { buffer []byte remaining []byte err error } // NewWriteBuffer creates a WriteBuffer wrapping the given slice func NewWriteBuffer(buffer []byte) *WriteBuffer { return &WriteBuffer{buffer: buffer, remaining: buffer} } // NewWriteBufferWithSize create a new WriteBuffer using an internal buffer of the given size func NewWriteBufferWithSize(size int) *WriteBuffer { return NewWriteBuffer(make([]byte, size)) } // WriteSingleByte writes a single byte to the buffer func (w *WriteBuffer) WriteSingleByte(n byte) { if w.err != nil { return } if len(w.remaining) == 0 { w.setErr(ErrBufferFull) return } w.remaining[0] = n w.remaining = w.remaining[1:] } // WriteBytes writes a slice of bytes to the buffer func (w *WriteBuffer) WriteBytes(in []byte) { if b := w.reserve(len(in)); b != nil { copy(b, in) } } // WriteUint16 writes a big endian encoded uint16 value to the buffer func (w *WriteBuffer) WriteUint16(n uint16) { if b := w.reserve(2); b != nil { binary.BigEndian.PutUint16(b, n) } } // WriteUint32 writes a big endian uint32 value to the buffer func (w *WriteBuffer) WriteUint32(n uint32) { if b := w.reserve(4); b != nil { binary.BigEndian.PutUint32(b, n) } } // WriteUint64 writes a big endian uint64 to the buffer func (w *WriteBuffer) WriteUint64(n uint64) { if b := w.reserve(8); b != nil { binary.BigEndian.PutUint64(b, n) } } // WriteUvarint writes an unsigned varint to the buffer func (w *WriteBuffer) WriteUvarint(n uint64) { // A uvarint could be up to 10 bytes long. 
buf := make([]byte, 10) varBytes := binary.PutUvarint(buf, n) if b := w.reserve(varBytes); b != nil { copy(b, buf[0:varBytes]) } } // WriteString writes a string to the buffer func (w *WriteBuffer) WriteString(s string) { // NB(mmihic): Don't just call WriteBytes; that will make a double copy // of the string due to the cast if b := w.reserve(len(s)); b != nil { copy(b, s) } } // WriteLen8String writes an 8-bit length preceded string func (w *WriteBuffer) WriteLen8String(s string) { if int(byte(len(s))) != len(s) { w.setErr(errStringTooLong) } w.WriteSingleByte(byte(len(s))) w.WriteString(s) } // WriteLen16String writes a 16-bit length preceded string func (w *WriteBuffer) WriteLen16String(s string) { if int(uint16(len(s))) != len(s) { w.setErr(errStringTooLong) } w.WriteUint16(uint16(len(s))) w.WriteString(s) } // DeferByte reserves space in the buffer for a single byte, and returns a // reference that can be used to update that byte later func (w *WriteBuffer) DeferByte() ByteRef { if len(w.remaining) == 0 { w.setErr(ErrBufferFull) return ByteRef(nil) } // Always zero out references, since the caller expects the default to be 0. 
w.remaining[0] = 0 bufRef := ByteRef(w.remaining[0:]) w.remaining = w.remaining[1:] return bufRef } // DeferUint16 reserves space in the buffer for a uint16, and returns a // reference that can be used to update that uint16 func (w *WriteBuffer) DeferUint16() Uint16Ref { return Uint16Ref(w.deferred(2)) } // DeferUint32 reserves space in the buffer for a uint32, and returns a // reference that can be used to update that uint32 func (w *WriteBuffer) DeferUint32() Uint32Ref { return Uint32Ref(w.deferred(4)) } // DeferUint64 reserves space in the buffer for a uint64, and returns a // reference that can be used to update that uint64 func (w *WriteBuffer) DeferUint64() Uint64Ref { return Uint64Ref(w.deferred(8)) } // DeferBytes reserves space in the buffer for a fixed sequence of bytes, and // returns a reference that can be used to update those bytes func (w *WriteBuffer) DeferBytes(n int) BytesRef { return BytesRef(w.deferred(n)) } func (w *WriteBuffer) deferred(n int) []byte { bs := w.reserve(n) for i := range bs { bs[i] = 0 } return bs } func (w *WriteBuffer) reserve(n int) []byte { if w.err != nil { return nil } if len(w.remaining) < n { w.setErr(ErrBufferFull) return nil } b := w.remaining[0:n] w.remaining = w.remaining[n:] return b } // BytesRemaining returns the number of available bytes remaining in the bufffer func (w *WriteBuffer) BytesRemaining() int { return len(w.remaining) } // FlushTo flushes the written buffer to the given writer. 
func (w *WriteBuffer) FlushTo(iow io.Writer) (int, error) { dirty := w.buffer[0:w.BytesWritten()] return iow.Write(dirty) } // BytesWritten returns the number of bytes that have been written to the buffer func (w *WriteBuffer) BytesWritten() int { return len(w.buffer) - len(w.remaining) } // Reset resets the buffer to an empty state, ready for writing func (w *WriteBuffer) Reset() { w.remaining = w.buffer w.err = nil } func (w *WriteBuffer) setErr(err error) { // Only store the first error if w.err != nil { return } w.err = err } // Err returns the current error in the buffer func (w *WriteBuffer) Err() error { return w.err } // Wrap initializes the buffer to wrap the given byte slice func (w *WriteBuffer) Wrap(b []byte) { w.buffer = b w.remaining = b } // A ByteRef is a reference to a byte in a bufffer type ByteRef []byte // Update updates the byte in the buffer func (ref ByteRef) Update(b byte) { if ref != nil { ref[0] = b } } // A Uint16Ref is a reference to a uint16 placeholder in a buffer type Uint16Ref []byte // Update updates the uint16 in the buffer func (ref Uint16Ref) Update(n uint16) { if ref != nil { binary.BigEndian.PutUint16(ref, n) } } // A Uint32Ref is a reference to a uint32 placeholder in a buffer type Uint32Ref []byte // Update updates the uint32 in the buffer func (ref Uint32Ref) Update(n uint32) { if ref != nil { binary.BigEndian.PutUint32(ref, n) } } // A Uint64Ref is a reference to a uin64 placeholder in a buffer type Uint64Ref []byte // Update updates the uint64 in the buffer func (ref Uint64Ref) Update(n uint64) { if ref != nil { binary.BigEndian.PutUint64(ref, n) } } // A BytesRef is a reference to a multi-byte placeholder in a buffer type BytesRef []byte // Update updates the bytes in the buffer func (ref BytesRef) Update(b []byte) { if ref != nil { copy(ref, b) } } // UpdateString updates the bytes in the buffer from a string func (ref BytesRef) UpdateString(s string) { if ref != nil { copy(ref, s) } } 
================================================ FILE: typed/buffer_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package typed import ( "bytes" "errors" "math" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSimple(t *testing.T) { buf := make([]byte, 200) var r ReadBuffer var w WriteBuffer { w.Wrap(buf) w.WriteSingleByte(0xFC) r.Wrap(buf) assert.Equal(t, byte(0xFC), r.ReadSingleByte()) } { w.Wrap(buf) w.WriteUint16(0xDEAD) r.Wrap(buf) assert.Equal(t, uint16(0xDEAD), r.ReadUint16()) } { w.Wrap(buf) w.WriteUint32(0xBEEFDEAD) r.Wrap(buf) assert.Equal(t, uint32(0xBEEFDEAD), r.ReadUint32()) } } func TestReadBufferSkipBytes(t *testing.T) { exampleBytes := make([]byte, 128) tests := []struct { msg string buf *ReadBuffer nSkip int wantError string wantRead int wantRemaining int }{ { msg: "successful skip", buf: NewReadBuffer(exampleBytes), nSkip: 64, wantRead: 64, wantRemaining: 64, }, { msg: "error occurred prior to skip", buf: func() *ReadBuffer { buf := NewReadBuffer(exampleBytes) buf.err = errors.New("something bad happened") return buf }(), nSkip: 64, wantError: "something bad happened", wantRead: 0, wantRemaining: 128, }, { msg: "not enough bytes remain", buf: func() *ReadBuffer { buf := NewReadBuffer(exampleBytes) return buf }(), nSkip: 256, wantError: "buffer is too small", wantRead: 0, wantRemaining: 128, }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { tt.buf.SkipBytes(tt.nSkip) if tt.wantError != "" { require.EqualError(t, tt.buf.Err(), tt.wantError, "Didn't get exepcted error") } assert.Equal(t, tt.wantRead, tt.buf.BytesRead()) assert.Equal(t, tt.wantRemaining, tt.buf.BytesRemaining()) }) } } func TestShortBuffer(t *testing.T) { r := NewReadBuffer([]byte{23}) assert.EqualValues(t, 0, r.ReadUint16()) assert.Equal(t, ErrEOF, r.Err()) } func TestReadWrite(t *testing.T) { s := "the small brown fix" bslice := []byte("jumped over the lazy dog") w := NewWriteBufferWithSize(1024) w.WriteUint64(0x0123456789ABCDEF) w.WriteUint32(0xABCDEF01) w.WriteUint16(0x2345) w.WriteUvarint(1) w.WriteUvarint(math.MaxInt16) 
w.WriteUvarint(math.MaxInt32) w.WriteUvarint(math.MaxInt64) w.WriteSingleByte(0xFF) w.WriteString(s) w.WriteBytes(bslice) w.WriteLen8String("hello") w.WriteLen16String("This is a much larger string") require.NoError(t, w.Err()) } func TestReadBufferTracking(t *testing.T) { bs := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} // Run twice, once on the original buffer, and once after wrapping. rbuf := NewReadBuffer(bs) for _, wrap := range []bool{false, true} { if wrap { rbuf.Wrap(bs) } t.Run("nothing read", func(t *testing.T) { assert.Equal(t, len(bs), rbuf.BytesRemaining(), "BytesRemaining") assert.Equal(t, bs, rbuf.Remaining(), "Remaining") assert.Zero(t, rbuf.BytesRead(), "BytesRead") }) t.Run("partially consumed", func(t *testing.T) { rbuf.ReadByte() rbuf.ReadUint32() assert.Equal(t, 5, rbuf.BytesRemaining(), "BytesRemaining") assert.Equal(t, bs[5:], rbuf.Remaining(), "Remaining") assert.Equal(t, 5, rbuf.BytesRead(), "BytesRead") }) t.Run("fully consumed", func(t *testing.T) { rbuf.ReadByte() rbuf.ReadUint32() assert.Zero(t, rbuf.BytesRemaining(), "BytesRemaining") assert.Empty(t, rbuf.Remaining(), "Remaining") assert.Equal(t, len(bs), rbuf.BytesRead(), "BytesRead") }) require.NoError(t, rbuf.Err()) } } func TestDeferredWrites(t *testing.T) { w := NewWriteBufferWithSize(1024) u16ref := w.DeferUint16() require.NotNil(t, u16ref) u32ref := w.DeferUint32() require.NotNil(t, u32ref) u64ref := w.DeferUint64() require.NotNil(t, u64ref) bref := w.DeferBytes(5) require.NotNil(t, bref) sref := w.DeferBytes(5) require.NotNil(t, sref) byteref := w.DeferByte() require.NotNil(t, byteref) assert.Equal(t, 2+4+8+5+5+1, w.BytesWritten()) u16ref.Update(2040) u32ref.Update(495404) u64ref.Update(0x40950459) bref.Update([]byte{0x30, 0x12, 0x45, 0x55, 0x65}) sref.UpdateString("where") byteref.Update(0x44) var buf bytes.Buffer w.FlushTo(&buf) r := NewReadBuffer(buf.Bytes()) u16 := r.ReadUint16() assert.Equal(t, uint16(2040), u16) u32 := r.ReadUint32() assert.Equal(t, uint32(495404), u32) u64 := 
r.ReadUint64() assert.Equal(t, uint64(0x40950459), u64) b := r.ReadBytes(5) assert.Equal(t, []byte{0x30, 0x12, 0x45, 0x55, 0x65}, b) s := r.ReadString(5) assert.Equal(t, "where", s) u8 := r.ReadSingleByte() assert.Equal(t, byte(0x44), u8) assert.NoError(t, r.Err()) } func TestDirtyUnderlyingBuffer(t *testing.T) { buf := make([]byte, 128) for i := range buf { buf[i] = ^byte(0) } w := NewWriteBuffer(buf) // Defer 1 + 2 + 4 + 8 + 5 = 20 bytes w.DeferByte() w.DeferUint16() w.DeferUint32() w.DeferUint64() w.DeferBytes(5) defer1 := w.DeferByte() defer2 := w.DeferUint16() defer3 := w.DeferUint32() defer4 := w.DeferUint64() defer5 := w.DeferBytes(5) w.WriteUint16(16) w.WriteUint32(32) w.WriteUint64(64) w.WriteLen16String("len16 string") w.WriteLen8String("len8 string") w.WriteString("string") w.WriteSingleByte(1) w.WriteBytes([]byte{1, 2, 3, 4, 5}) defer1.Update(11) defer2.Update(116) defer3.Update(132) defer4.Update(164) defer5.Update([]byte{11, 12, 13, 14, 15}) r := NewReadBuffer(buf) // Deferred unwritten bytes should be 0. assert.EqualValues(t, 0, r.ReadSingleByte(), "unwritten deferred should be 0") assert.EqualValues(t, 0, r.ReadUint16(), "unwritten deferred should be 0") assert.EqualValues(t, 0, r.ReadUint32(), "unwritten deferred should be 0") assert.EqualValues(t, 0, r.ReadUint64(), "unwritten deferred should be 0") assert.Equal(t, []byte{0, 0, 0, 0, 0}, r.ReadBytes(5), "unwritten deferred should be 0") // Deferred written bytes. assert.EqualValues(t, 11, r.ReadSingleByte(), "defer byte") assert.EqualValues(t, 116, r.ReadUint16(), "defer uint16") assert.EqualValues(t, 132, r.ReadUint32(), "defer uint32") assert.EqualValues(t, 164, r.ReadUint64(), "defer uint64") assert.Equal(t, []byte{11, 12, 13, 14, 15}, r.ReadBytes(5), "defer bytes") // Normally written bytes. 
assert.EqualValues(t, 16, r.ReadUint16(), "uint16") assert.EqualValues(t, 32, r.ReadUint32(), "uint32") assert.EqualValues(t, 64, r.ReadUint64(), "uint64") assert.Equal(t, "len16 string", r.ReadLen16String(), "len16 string") assert.Equal(t, "len8 string", r.ReadLen8String(), "len 8 string") assert.Equal(t, "string", r.ReadString(6), "string") assert.EqualValues(t, 1, r.ReadSingleByte(), "byte") assert.Equal(t, []byte{1, 2, 3, 4, 5}, r.ReadBytes(5), "bytes") } ================================================ FILE: typed/reader.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package typed import ( "encoding/binary" "io" "sync" ) const maxPoolStringLen = 32 // Reader is a reader that reads typed values from an io.Reader. 
type Reader struct { reader io.Reader err error buf [maxPoolStringLen]byte } var readerPool = sync.Pool{ New: func() interface{} { return &Reader{} }, } // NewReader returns a reader that reads typed values from the reader. func NewReader(reader io.Reader) *Reader { r := readerPool.Get().(*Reader) r.reader = reader r.err = nil return r } // ReadUint16 reads a uint16. func (r *Reader) ReadUint16() uint16 { if r.err != nil { return 0 } buf := r.buf[:2] var readN int readN, r.err = io.ReadFull(r.reader, buf) if readN < 2 { return 0 } return binary.BigEndian.Uint16(buf) } // ReadString reads a string of length n. func (r *Reader) ReadString(n int) string { if r.err != nil { return "" } var buf []byte if n <= maxPoolStringLen { buf = r.buf[:n] } else { buf = make([]byte, n) } var readN int readN, r.err = io.ReadFull(r.reader, buf) if readN < n { return "" } s := string(buf) return s } // ReadLen16String reads a uint16-length prefixed string. func (r *Reader) ReadLen16String() string { len := r.ReadUint16() return r.ReadString(int(len)) } // Err returns any errors hit while reading from the underlying reader. func (r *Reader) Err() error { return r.err } // Release puts the Reader back in the pool. func (r *Reader) Release() { readerPool.Put(r) } ================================================ FILE: typed/reader_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package typed import ( "bytes" "io" "testing" "github.com/stretchr/testify/assert" "github.com/uber/tchannel-go/testutils/testreader" ) func nString(n int) []byte { buf := make([]byte, n) reader := testreader.Looper([]byte{'a', 'b', 'c', 'd', 'e'}) io.ReadFull(reader, buf) return buf } func TestReader(t *testing.T) { s1 := nString(10) s2 := nString(800) var buf []byte buf = append(buf, 0, 1) // uint16, 1 buf = append(buf, 0xff, 0xff) // uint16, 65535 buf = append(buf, 0, 10) // uint16, 10 buf = append(buf, s1...) // string, 10 bytes buf = append(buf, 3, 32) // uint16, 800 buf = append(buf, s2...) 
// string, 800 bytes buf = append(buf, 0, 10) // uint16, 10 reader := NewReader(bytes.NewReader(buf)) assert.Equal(t, uint16(1), reader.ReadUint16()) assert.Equal(t, uint16(65535), reader.ReadUint16()) assert.Equal(t, string(s1), reader.ReadLen16String()) assert.Equal(t, string(s2), reader.ReadLen16String()) assert.Equal(t, uint16(10), reader.ReadUint16()) } func TestReaderErr(t *testing.T) { tests := []struct { chunks [][]byte validation func(reader *Reader) }{ { chunks: [][]byte{ {0, 1}, nil, {2, 3}, }, validation: func(reader *Reader) { assert.Equal(t, uint16(1), reader.ReadUint16(), "Read unexpected value") assert.Equal(t, uint16(0), reader.ReadUint16(), "Expected default value") }, }, { chunks: [][]byte{ {0, 4}, []byte("test"), nil, {'A', 'b'}, }, validation: func(reader *Reader) { assert.Equal(t, "test", reader.ReadLen16String(), "Read unexpected value") assert.Equal(t, "", reader.ReadString(2), "Expected default value") }, }, } for _, tt := range tests { writer, chunkReader := testreader.ChunkReader() reader := NewReader(chunkReader) defer reader.Release() for _, chunk := range tt.chunks { writer <- chunk } close(writer) tt.validation(reader) // Once there's an error, all further calls should fail. assert.Equal(t, testreader.ErrUser, reader.Err(), "Unexpected error") assert.Equal(t, uint16(0), reader.ReadUint16(), "Expected default value") assert.Equal(t, "", reader.ReadString(1), "Expected default value") assert.Equal(t, "", reader.ReadLen16String(), "Expected default value") assert.Equal(t, testreader.ErrUser, reader.Err(), "Unexpected error") } } ================================================ FILE: typed/writer.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. 
// Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. 
package typed import ( "encoding/binary" "io" "sync" ) type intBuffer [8]byte var intBufferPool = sync.Pool{New: func() interface{} { return new(intBuffer) }} // Writer is a writer that writes typed values to an io.Writer type Writer struct { writer io.Writer err error } // NewWriter creates a writer that writes typed value to a reader func NewWriter(w io.Writer) *Writer { return &Writer{ writer: w, } } // WriteBytes writes a slice of bytes to the io.Writer func (w *Writer) WriteBytes(b []byte) { if w.err != nil { return } if _, err := w.writer.Write(b); err != nil { w.err = err } } // WriteUint16 writes a uint16 to the io.Writer func (w *Writer) WriteUint16(n uint16) { if w.err != nil { return } sizeBuf := intBufferPool.Get().(*intBuffer) defer intBufferPool.Put(sizeBuf) binary.BigEndian.PutUint16(sizeBuf[:2], n) if _, err := w.writer.Write(sizeBuf[:2]); err != nil { w.err = err } } // WriteLen16Bytes writes a slice of bytes to the io.Writer preceded with // the length of the slice func (w *Writer) WriteLen16Bytes(b []byte) { if w.err != nil { return } w.WriteUint16(uint16(len(b))) w.WriteBytes(b) } // Err returns the error state of the writer func (w *Writer) Err() error { return w.err } ================================================ FILE: typed/writer_test.go ================================================ package typed import ( "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type dummyWriter struct { calls int bytesWritten []byte // retError is a map of call ids to error strings retError map[int]string } func (w *dummyWriter) Write(b []byte) (int, error) { defer func() { w.calls++ }() if w.retError[w.calls] != "" { return 0, errors.New(w.retError[w.calls]) } w.bytesWritten = append(w.bytesWritten, b...) 
return len(b), nil } func TestWriter(t *testing.T) { tests := []struct { msg string w *dummyWriter previousError error wantError string wantBytesWritten []byte }{ { msg: "successful write", w: &dummyWriter{ retError: map[int]string{}, }, wantBytesWritten: []byte{0, 1, 2, 0, 3, 4, 5, 6}, }, { msg: "return error due to previous error", previousError: errors.New("something went wrong previously"), w: &dummyWriter{}, wantError: "something went wrong previously", }, { msg: "error writing length", w: &dummyWriter{ retError: map[int]string{0: "something went wrong"}, }, wantError: "something went wrong", }, { msg: "error writing data", w: &dummyWriter{ retError: map[int]string{1: "something went wrong"}, }, wantError: "something went wrong", }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { writes := func(w *Writer) { w.WriteUint16(1) w.WriteBytes([]byte{2}) w.WriteLen16Bytes([]byte{4, 5, 6}) } w := NewWriter(tt.w) w.err = tt.previousError writes(w) if tt.wantError != "" { require.EqualError(t, w.Err(), tt.wantError, "Got unexpected error") return } require.NoError(t, w.Err(), "Got unexpected error") assert.Equal(t, tt.wantBytesWritten, tt.w.bytesWritten) }) } } ================================================ FILE: utils_for_test.go ================================================ // Copyright (c) 2015 Uber Technologies, Inc. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

// This file contains functions for tests to access internal tchannel state.
// Since it has a _test.go suffix, it is only compiled with tests in this package.

import (
	"net"

	"golang.org/x/net/context"
)

// MexChannelBufferSize is the size of the message exchange channel buffer.
const MexChannelBufferSize = mexChannelBufferSize

// SetOnUpdate sets onUpdate for a peer, which is called when the peer's score is
// updated in all peer lists. The assignment is made while holding the peer's
// lock.
func (p *Peer) SetOnUpdate(f func(*Peer)) {
	p.Lock()
	p.onUpdate = f
	p.Unlock()
}

// SetRandomSeed seeds all the random number generators in the channel so that
// tests will be deterministic for a given seed.
func (ch *Channel) SetRandomSeed(seed int64) {
	// The channel's own peer heap and the package-level peerRng both use
	// the seed as given.
	ch.Peers().peerHeap.rng.Seed(seed)
	peerRng.Seed(seed)
	// Each subchannel's peer heap is seeded with the seed offset by the
	// number of entries in that subchannel's peersByHostPort.
	// NOTE(review): the subchannels map is iterated without taking a lock
	// here - presumably fine because this is only called during test
	// setup; confirm.
	for _, sc := range ch.subChannels.subchannels {
		sc.peers.peerHeap.rng.Seed(seed + int64(len(sc.peers.peersByHostPort)))
	}
}

// Ping exports ping for tests.
func (c *Connection) Ping(ctx context.Context) error {
	return c.ping(ctx)
}

// Logger returns the logger for the specific connection for tests.
func (c *Connection) Logger() Logger {
	return c.log
}

// StopHealthCheck exports stopHealthCheck for tests.
func (c *Connection) StopHealthCheck() {
	c.stopHealthCheck()
}

// OutboundConnection returns the underlying connection for an outbound call:
// the call's *Connection together with that connection's net.Conn.
func OutboundConnection(call *OutboundCall) (*Connection, net.Conn) {
	conn := call.conn
	return conn, conn.conn
}

// InboundConnection returns the underlying connection for an incoming call.
// It returns (nil, nil) if call is not an *InboundCall.
func InboundConnection(call IncomingCall) (*Connection, net.Conn) {
	// Only the concrete *InboundCall exposes the connection.
	inboundCall, ok := call.(*InboundCall)
	if !ok {
		return nil, nil
	}

	conn := inboundCall.conn
	return conn, inboundCall.Connection()
}

================================================
FILE: verify_utils_test.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel_test

import (
	"testing"

	. "github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/testutils"
)

// WithVerifiedServer runs the given test function with a server channel that is verified
// at the end to make sure there are no leaks (e.g. no exchanges leaked).
// It delegates to testutils.WithTestServer and passes the test server's
// channel and host:port to f.
// NOTE(review): nothing in this function performs the verification itself -
// presumably testutils.WithTestServer does; confirm.
func WithVerifiedServer(t *testing.T, opts *testutils.ChannelOpts, f func(serverCh *Channel, hostPort string)) {
	testutils.WithTestServer(t, opts, func(t testing.TB, ts *testutils.TestServer) {
		f(ts.Server(), ts.HostPort())
	})
}

================================================
FILE: version.go
================================================
// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package tchannel

// VersionInfo identifies the version of the TChannel library.
// Due to lack of proper package management, this version string will
// be maintained manually.
const VersionInfo = "1.34.6-dev"