Repository: d5h-foss/grpc-interceptor Branch: master Commit: d866a1bd09e9 Files: 42 Total size: 132.5 KB Directory structure: gitextract_8jb1zy9l/ ├── .flake8 ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── something-else.md │ ├── release-drafter.yml │ └── workflows/ │ ├── coverage.yml │ ├── mindeps.yml │ ├── release-drafter.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .readthedocs.yml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── docs/ │ ├── conf.py │ ├── index.rst │ ├── license.rst │ ├── reference.rst │ └── requirements.txt ├── mypy.ini ├── noxfile.py ├── pyproject.toml ├── src/ │ └── grpc_interceptor/ │ ├── __init__.py │ ├── client.py │ ├── exception_to_status.py │ ├── exceptions.py │ ├── py.typed │ ├── server.py │ └── testing/ │ ├── __init__.py │ ├── dummy_client.py │ └── protos/ │ ├── __init__.py │ ├── dummy.proto │ ├── dummy_pb2.py │ ├── dummy_pb2.pyi │ └── dummy_pb2_grpc.py └── tests/ ├── __init__.py ├── test_client.py ├── test_exception_to_status.py ├── test_server.py └── test_streaming.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .flake8 ================================================ [flake8] select = B,B9,C,D,E,F,I,S,W exclude = *_pb2.py,*_pb2_grpc.py ignore = D107,W503 application-import-names = grpc_interceptor,tests import-order-style = google max-complexity = 10 max-line-length = 88 # asserts in tests are OK per-file-ignores = tests/*:S101 docstring-convention = google ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- ### What versions of the following are you using? * `python`: * `grpc-interceptor`: * `grpcio`: * `protobuf`: ### What operating system (Linux, Windows,...) 
and version? ### What did you do? Please provide specific steps to reproduce the bug. ### What did you expect to see? ### What did you see instead? Please include any information that's helpful for debugging (full error message, exception listing, stack trace, logs). ### Anything else we should know about your project / environment? ================================================ FILE: .github/ISSUE_TEMPLATE/something-else.md ================================================ --- name: Something else about: Feature requests and other things title: '' labels: '' assignees: '' --- ================================================ FILE: .github/release-drafter.yml ================================================ categories: - title: ':boom: Breaking Changes' label: 'breaking' - title: ':package: Build System' label: 'build' - title: ':construction_worker: Continuous Integration' label: 'ci' - title: ':books: Documentation' label: 'documentation' - title: ':rocket: Features' label: 'enhancement' - title: ':beetle: Fixes' label: 'bug' - title: ':racehorse: Performance' label: 'performance' - title: ':hammer: Refactoring' label: 'refactoring' - title: ':fire: Removals and Deprecations' label: 'removal' - title: ':lipstick: Style' label: 'style' - title: ':rotating_light: Testing' label: 'testing' template: | ## What’s Changed $CHANGES ================================================ FILE: .github/workflows/coverage.yml ================================================ name: Coverage on: [push, pull_request] jobs: coverage: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v4 with: python-version: '3.11' architecture: x64 - run: pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9 - run: nox --sessions tests-3.11 - uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} ================================================ FILE: .github/workflows/mindeps.yml ================================================ name: Minimum 
Dependencies on: [push, pull_request] jobs: mindeps: runs-on: ubuntu-latest container: python:3.7-slim steps: - name: Installing dependencies run: | pip install --upgrade pip && pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9 - uses: actions/checkout@v2 - run: | cd "$GITHUB_WORKSPACE" && mkdir .nox && nox --sessions mindeps ================================================ FILE: .github/workflows/release-drafter.yml ================================================ name: Release Drafter on: push: branches: - master jobs: draft_release: runs-on: ubuntu-latest steps: - uses: release-drafter/release-drafter@v5.6.1 env: GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: release: types: [published] jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v4 with: python-version: '3.9' architecture: x64 - run: pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9 - run: nox - run: poetry build - uses: actions/upload-artifact@v3 with: name: dist path: dist/ release: needs: build runs-on: ubuntu-latest permissions: id-token: write environment: name: pypi url: https://pypi.org/p/grpc-interceptor steps: - uses: actions/download-artifact@v3 with: name: dist path: dist - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 ================================================ FILE: .github/workflows/tests.yml ================================================ name: Tests on: [push, pull_request] jobs: tests: strategy: fail-fast: false matrix: platform: [ubuntu-latest, macos-latest, windows-latest] python-version: ['3.11', '3.10', '3.9', '3.8', '3.7'] name: Python ${{matrix.python-version}} ${{matrix.platform}} runs-on: ${{matrix.platform}} steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v4 with: python-version: ${{matrix.python-version}} 
architecture: x64 - run: pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9 - run: nox --python ${{matrix.python-version}} ================================================ FILE: .gitignore ================================================ /.coverage /.idea/ /.nox/ /.venv/ /.vscode/ /coverage.xml /dist/ /docs/_build/ /src/*.egg-info/ __pycache__/ poetry.toml ================================================ FILE: .readthedocs.yml ================================================ version: 2 sphinx: configuration: docs/conf.py formats: all python: version: 3.8 install: - requirements: docs/requirements.txt - path: . extra_requirements: [testing] ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based roughly on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [0.13.1] - 2022-01-29 ### Added - Run tests on Python 3.9 and default to that for linting, etc. ### Fixed - Raise UNIMPLEMENTED instead of UNKNOWN when no handler is found for a method (thanks Richard Mahlberg!) ## [0.13.0] - 2020-12-27 ### Added - Client-side interceptors (thanks Michael Morgan!) 
### Changed (breaking) - Added a `context` parameter to special case functions in the `testing` module ### Fixed - Build issue caused by [pip upgrade](https://github.com/cjolowicz/hypermodern-python/issues/174#issuecomment-745364836) - Docs not building correctly in nox ## [0.12.0] - 2020-10-07 ### Added - Support for all streaming RPCs ## [0.11.0] - 2020-07-24 ### Added - Expose some imports from the top-level ### Changed (breaking) - Rename to `ServerInterceptor` (do not intend to make breaking name changes after this) ## [0.10.0] - 2020-07-23 ### Added - `status_on_unknown_exception` to `ExceptionToStatusInterceptor` - `py.typed` (so `mypy` will type check the package) ### Fixed - Allow protobuf version 4.0.x which is coming out soon and is backwards compatible - Testing in autodocs - Turn on xdoctest - Prevent autodoc from outputting default namedtuple docs ### Changed (breaking) - Rename `Interceptor` to `ServiceInterceptor` ## [0.9.0] - 2020-07-22 ### Added - The `testing` module - Some helper functions - Improved test coverage ### Fixed - Protobuf compatibility improvements ## [0.8.0] - 2020-07-19 ### Added - An `Interceptor` base class, to make it easy to define your own service interceptors - An `ExceptionToStatusInterceptor` interceptor ================================================ FILE: CONTRIBUTING.md ================================================ # Running Tests This will run the unit tests quickly: ``` poetry install make tests ``` It doesn't run the entire test suite. See below for that. # Making a Pull Request Please bump the version number in `pyproject.toml` when you make a pull request. This is needed to give the package a new version. Also run lint checks and mypy before pushing. This runs in Github Actions as well, but you'll get faster feedback by running it locally. To do this, run `nox -s lint` and `nox -s mypy-3.x`, for whatever version of Python you have installed. For example, if you're using Python 3.9, run `nox -s mypy-3.9`. 
If you need to make formatting changes, you can run `nox -s black`. Note that `nox` isn't installed via `poetry`, due to the way it works, so you'll need to install it globally. If you don't want to install `nox`, you can do all this in docker. For example, you can run this to mount the current directory into /app and work from there: ``` docker run --rm -it --mount type=bind,src="$(pwd)",dst=/app python:3.9 bash ``` # Adding Tests Add both a sync and async version if applicable. You can follow the examples in many tests for this. Search for `aio` to find one. Assuming the test applies to both sync and async code, it will need to create a different test interceptor depending on the value of `aio`. Then just remember to pass `aio_server=aio` to `dummy_client`. The rest of the test can be the same for both sync and async. This is because the tests create a client which calls a server. The client doesn't care whether the server is sync or async. # On Changing Dependencies I want to keep this library very small, not just in terms of its own code, but in terms of the code it pulls in. Having many dependencies is a burden to users. It increases installation time (especially when solving constraints with newer pip or poetry). It increases the likelihood of dependency conflicts, and generally just introduces more that can go wrong. Hence, this library depends on as little as possible, and can hopefully stay that way. 
================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 Dan Hipschman Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Makefile ================================================ TEST_PROTOS := src/grpc_interceptor/testing/protos/dummy.proto TEST_PROTO_GEN := $(shell echo $(TEST_PROTOS) | sed 's/\.proto/_pb2.py/g') \ $(shell echo $(TEST_PROTOS) | sed 's/\.proto/_pb2_grpc.py/g') $(TEST_PROTO_GEN): $(TEST_PROTOS) cd src && \ printf "%s\n" $(TEST_PROTOS) | \ sed 's|^src/||' | \ xargs poetry run python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. --mypy_out=. 
.PHONY: test test: $(TEST_PROTO_GEN) poetry run pytest --cov .PHONY: nox-test nox-test: $(TEST_PROTO_GEN) nox -r ================================================ FILE: README.md ================================================ [![Tests](https://github.com/d5h-foss/grpc-interceptor/workflows/Tests/badge.svg)](https://github.com/d5h-foss/grpc-interceptor/actions?workflow=Tests) [![Codecov](https://codecov.io/gh/d5h-foss/grpc-interceptor/branch/master/graph/badge.svg)](https://codecov.io/gh/d5h-foss/grpc-interceptor) [![Read the Docs](https://readthedocs.org/projects/grpc-interceptor/badge/)](https://grpc-interceptor.readthedocs.io/) [![PyPI](https://img.shields.io/pypi/v/grpc-interceptor.svg)](https://pypi.org/project/grpc-interceptor/) # Summary Simplified Python gRPC interceptors. The Python `grpc` package provides service interceptors, but they're a bit hard to use because of their flexibility. The `grpc` interceptors don't have direct access to the request and response objects, or the service context. Access to these are often desired, to be able to log data in the request or response, or set status codes on the context. # Installation To just get the interceptors (and probably not write your own): ```console $ pip install grpc-interceptor ``` To also get the testing framework, which is good if you're writing your own interceptors: ```console $ pip install grpc-interceptor[testing] ``` # Usage ## Server Interceptor To define your own interceptor (we can use `ExceptionToStatusInterceptor` as an example): ```python from grpc_interceptor import ServerInterceptor from grpc_interceptor.exceptions import GrpcException class ExceptionToStatusInterceptor(ServerInterceptor): def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str, ) -> Any: """Override this method to implement a custom interceptor. 
You should call method(request_or_iterator, context) to invoke the next handler (either the RPC method implementation, or the next interceptor in the list). Args: method: The next interceptor, or method implementation. request_or_iterator: The RPC request, as a protobuf message. context: The ServicerContext passed by gRPC to the service. method_name: A string of the form "/protobuf.package.Service/Method" Returns: This should generally return the result of method(request_or_iterator, context), which is typically the RPC method response, as a protobuf message. The interceptor is free to modify this in some way, however. """ try: return method(request_or_iterator, context) except GrpcException as e: context.set_code(e.status_code) context.set_details(e.details) raise ``` Then inject your interceptor when you create the `grpc` server: ```python interceptors = [ExceptionToStatusInterceptor()] server = grpc.server( futures.ThreadPoolExecutor(max_workers=10), interceptors=interceptors ) ``` To use `ExceptionToStatusInterceptor`: ```python from grpc_interceptor.exceptions import NotFound class MyService(my_pb2_grpc.MyServiceServicer): def MyRpcMethod( self, request: MyRequest, context: grpc.ServicerContext ) -> MyResponse: thing = lookup_thing() if not thing: raise NotFound("Sorry, your thing is missing") ... ``` This results in the gRPC status code being set to `NOT_FOUND`, and the details `"Sorry, your thing is missing"`. This saves you the hassle of catching exceptions in your service handler, or passing the context down into helper functions so they can call `context.abort` or `context.set_code`. It allows the more Pythonic approach of just raising an exception from anywhere in the code, and having it be handled automatically.
## Client Interceptor We will use an invocation metadata injecting interceptor as an example of defining a client interceptor: ```python from grpc_interceptor import ClientCallDetails, ClientInterceptor class MetadataClientInterceptor(ClientInterceptor): def intercept( self, method: Callable, request_or_iterator: Any, call_details: grpc.ClientCallDetails, ): """Override this method to implement a custom interceptor. This method is called for all unary and streaming RPCs. The interceptor implementation should call `method` using a `grpc.ClientCallDetails` and the `request_or_iterator` object as parameters. The `request_or_iterator` parameter may be type checked to determine if this is a singluar request for unary RPCs or an iterator for client-streaming or client-server streaming RPCs. Args: method: A function that proceeds with the invocation by executing the next interceptor in the chain or invoking the actual RPC on the underlying channel. request_or_iterator: RPC request message or iterator of request messages for streaming requests. call_details: Describes an RPC to be invoked. Returns: The type of the return should match the type of the return value received by calling `method`. This is an object that is both a `Call `_ for the RPC and a `Future `_. The actual result from the RPC can be got by calling `.result()` on the value returned from `method`. """ new_details = ClientCallDetails( call_details.method, call_details.timeout, [("authorization", "Bearer mysecrettoken")], call_details.credentials, call_details.wait_for_ready, call_details.compression, ) return method(request_or_iterator, new_details) ``` Now inject your interceptor when you create the ``grpc`` channel: ```python interceptors = [MetadataClientInterceptor()] with grpc.insecure_channel("grpc-server:50051") as channel: channel = grpc.intercept_channel(channel, *interceptors) ... 
``` Client interceptors can also be used to [retry RPCs](https://github.com/d5h-foss/grpc-interceptor/blob/4b6bb6a59aae97aec058c0d4072dd19de8f408bc/tests/test_client.py#L39-L56) that fail due to specific errors, or a host of other use cases. There are some basic approaches in [the tests](https://github.com/d5h-foss/grpc-interceptor/blob/master/tests/test_client.py) to get you started. Note: The `method` in a client interceptor is a `continuation` as described in the [client interceptor section of the gRPC docs](https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryClientInterceptor.intercept_unary_unary). When you invoke the continuation, you get a future back, which resolves to either the result or an exception. This is different from invoking a client stub, which returns the result directly. If the interceptor needs the value returned by the call, or to catch exceptions, then you'll need to do `future = method(request_or_iterator, call_details)`, followed by `future.result()`. Check out the tests for [examples](https://github.com/d5h-foss/grpc-interceptor/blob/4b6bb6a59aae97aec058c0d4072dd19de8f408bc/tests/test_client.py#L39-L56). # Documentation The examples above showed usage for simple unary-unary RPC calls. For examples of streaming and asyncio RPCs, read the [complete documentation here](https://grpc-interceptor.readthedocs.io/). Note that there are no asyncio client interceptors at the moment, though contributions are welcome.
================================================ FILE: docs/conf.py ================================================ """Sphinx configuration.""" import re project = "grpc-interceptor" author = "Dan Hipschman" copyright = f"2020, {author}" extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", ] def setup(app): """Sphinx setup hook: connect skip_member to the autodoc-skip-member event.""" app.connect("autodoc-skip-member", skip_member) def skip_member(app, what, name, obj, skip, options): """Decide whether autodoc should skip a member. Returns True when the member's docstring is exactly namedtuple's auto-generated 'Alias for field number N' text; otherwise returns the incoming ``skip`` decision unchanged.""" doc = getattr(obj, "__doc__", "") or "" # Handle when __doc__ is missing or None is_namedtuple_docstring = bool(re.fullmatch("Alias for field number [0-9]+", doc)) return is_namedtuple_docstring or skip ================================================ FILE: docs/index.rst ================================================ Simplified Python gRPC Interceptors =================================== .. toctree:: :hidden: :maxdepth: 1 reference license .. contents:: The primary aim of this project is to make Python gRPC interceptors simple. The Python ``grpc`` package provides service interceptors, but they're a bit hard to use because of their flexibility. The ``grpc`` interceptors don't have direct access to the request and response objects, or the service context. Access to these are often desired, to be able to log data in the request or response, or set status codes on the context. The secondary aim of this project is to keep the code small and simple. Code you can read through and understand quickly gives you confidence and helps debug issues. When you install this package, you also don't want a bunch of other packages that might cause conflicts within your project. Too many dependencies slow down installation as well as runtime (fresh imports take time). Hence, a goal of this project is to keep dependencies to a minimum. The only core dependency is the ``grpc`` package, and the ``testing`` extra includes ``protobuf`` as well.
The ``grpc_interceptor`` package provides the following: * A ``ServerInterceptor`` base class, to make it easy to define your own server-side interceptors. Do not confuse this with the ``grpc.ServerInterceptor`` class. * An ``AsyncServerInterceptor`` base class, which is the analogue of ``ServerInterceptor`` for async server-side interceptors. * An ``ExceptionToStatusInterceptor`` interceptor, so your service can raise exceptions that set the gRPC status code correctly (rather than the default of every exception resulting in an ``UNKNOWN`` status code). This is something for which pretty much any service will have a use. * An ``AsyncExceptionToStatusInterceptor`` interceptor, which is the async analogue of ``ExceptionToStatusInterceptor``. * A ``ClientInterceptor`` base class, to make it easy to define your own client-side interceptors. Do not confuse this with the ``grpc.ClientInterceptor`` class. (Note, there is currently no async analogue of ``ClientInterceptor``, though contributions are welcome.) * An optional testing framework. If you're writing your own interceptors, this is useful. If you're just using ``ExceptionToStatusInterceptor`` then you don't need this. Installation ------------ To install just the interceptors: .. code-block:: console $ pip install grpc-interceptor To also install the testing framework: .. code-block:: console $ pip install grpc-interceptor[testing] Usage ----- Server Interceptors ^^^^^^^^^^^^^^^^^^^ To define your own server interceptor (we can use a simplified version of ``ExceptionToStatusInterceptor`` as an example): .. code-block:: python from grpc_interceptor import ServerInterceptor from grpc_interceptor.exceptions import GrpcException class ExceptionToStatusInterceptor(ServerInterceptor): def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str, ) -> Any: """Override this method to implement a custom interceptor.
You should call method(request_or_iterator, context) to invoke the next handler (either the RPC method implementation, or the next interceptor in the list). Args: method: The next interceptor, or method implementation. request_or_iterator: The RPC request, as a protobuf message. context: The ServicerContext passed by gRPC to the service. method_name: A string of the form "/protobuf.package.Service/Method" Returns: This should generally return the result of method(request_or_iterator, context), which is typically the RPC method response, as a protobuf message. The interceptor is free to modify this in some way, however. """ try: return method(request_or_iterator, context) except GrpcException as e: context.set_code(e.status_code) context.set_details(e.details) raise Then inject your interceptor when you create the ``grpc`` server: .. code-block:: python interceptors = [ExceptionToStatusInterceptor()] server = grpc.server( futures.ThreadPoolExecutor(max_workers=10), interceptors=interceptors ) To use ``ExceptionToStatusInterceptor``: .. code-block:: python from grpc_interceptor.exceptions import NotFound class MyService(my_pb2_grpc.MyServiceServicer): def MyRpcMethod( self, request: MyRequest, context: grpc.ServicerContext ) -> MyResponse: thing = lookup_thing() if not thing: raise NotFound("Sorry, your thing is missing") ... This results in the gRPC status code being set to ``NOT_FOUND``, and the details ``"Sorry, your thing is missing"``. This saves you the hassle of catching exceptions in your service handler, or passing the context down into helper functions so they can call ``context.abort`` or ``context.set_code``. It allows the more Pythonic approach of just raising an exception from anywhere in the code, and having it be handled automatically. Server Streaming Interceptors """"""""""""""""""""""""""""" The above example shows how to write an interceptor for a unary-unary RPC.
Server streaming RPCs need to be handled a little differently because ``method(request, context)`` will return a generator. Hence, the code won't actually run until you iterate over it. So, if we were to continue the example of catching exceptions from RPCs, we would need to do something like this: .. code-block:: python class ExceptionToStatusInterceptor(ServerInterceptor): def intercept( self, method: Callable, request: Any, context: grpc.ServicerContext, method_name: str, ) -> Any: try: for response in method(request, context): yield response except GrpcException as e: context.set_code(e.status_code) context.set_details(e.details) raise However, this will *only* work for server streaming RPCs. In order to work with both unary and streaming RPCs, you'll need to handle the unary case and streaming case separately, like this: .. code-block:: python class ExceptionToStatusInterceptor(ServerInterceptor): def intercept(self, method, request, context, method_name): # Call the RPC. It could be either unary or streaming try: response_or_iterator = method(request, context) except GrpcException as e: # If it was unary, then any exception raised would be caught # immediately, so handle it here. context.set_code(e.status_code) context.set_details(e.details) raise # Check if it's streaming if hasattr(response_or_iterator, "__iter__"): # Now we know it's a server streaming RPC, so the actual RPC method # hasn't run yet. Delegate to a helper to iterate over it so it runs. # The helper needs to re-yield the responses, and we need to return # the generator it produces. return self._intercept_streaming(response_or_iterator, context) else: # For unary cases, we are done, so just return the response.
return response_or_iterator def _intercept_streaming(self, iterator, context): try: for resp in iterator: yield resp except GrpcException as e: context.set_code(e.status_code) context.set_details(e.details) raise Async Server Interceptors """"""""""""""""""""""""" Async interceptors are similar to sync ones, but there are two things of which you need to be aware. First, async server streaming RPCs that are implemented with ``async def + yield`` cannot be awaited. When you call such a method, you get back an ``async_generator``. This is not ``await``-able (though you can ``async for`` loop over it). This is in contrast to a unary RPC, which is implemented with ``async def + return``. That results in a coroutine when called, which you *can* ``await``. All this is to say that you mustn't await ``method(request, context)`` in an async interceptor immediately. First, check if it's an ``async_generator``. You can do this by checking for the presence of the ``__aiter__`` attribute. Here's an async version of our running ``ExceptionToStatusInterceptor`` example: .. code-block:: python from grpc_interceptor.exceptions import GrpcException from grpc_interceptor.server import AsyncServerInterceptor class AsyncExceptionToStatusInterceptor(AsyncServerInterceptor): async def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str, ) -> Any: try: response_or_iterator = method(request_or_iterator, context) if not hasattr(response_or_iterator, "__aiter__"): # Unary, just await and return the response return await response_or_iterator except GrpcException as e: context.set_code(e.status_code) context.set_details(e.details) raise # Server streaming responses, delegate to an async generator helper. # Note that we do NOT await this.
return self._intercept_streaming(response_or_iterator, context) async def _intercept_streaming(self, iterator, context): try: async for r in iterator: yield r except GrpcException as e: context.set_code(e.status_code) context.set_details(e.details) raise The second thing you must be aware of with async RPCs is that an `alternate streaming API `_ was added. With this API, instead of writing a server streaming RPC with ``async def + yield``, you write it as ``async def + return``, but it returns ``None``. The way it streams responses is by calling ``await context.write(...)`` for each response it streams. Similarly, client streaming can be achieved by calling ``await context.read()`` instead of iterating over the request object. If you must support RPC services written using this new API, then you must be aware that a server streaming RPC could return ``None``. In that case it will not be an ``async_generator`` even though it's streaming. You will also need your own solution to get access to the streaming response objects. For example, you could wrap the ``context`` object that you pass to ``method(request, context)``, so that you can capture ``read`` and ``write`` calls. Client Interceptors ^^^^^^^^^^^^^^^^^^^ We will use an invocation metadata injecting interceptor as an example of defining a client interceptor: .. code-block:: python from grpc_interceptor import ClientCallDetails, ClientInterceptor class MetadataClientInterceptor(ClientInterceptor): def intercept( self, method: Callable, request_or_iterator: Any, call_details: grpc.ClientCallDetails, ): """Override this method to implement a custom interceptor. This method is called for all unary and streaming RPCs. The interceptor implementation should call `method` using a `grpc.ClientCallDetails` and the `request_or_iterator` object as parameters.
The `request_or_iterator` parameter may be type checked to determine if this is a singluar request for unary RPCs or an iterator for client-streaming or client-server streaming RPCs. Args: method: A function that proceeds with the invocation by executing the next interceptor in the chain or invoking the actual RPC on the underlying channel. request_or_iterator: RPC request message or iterator of request messages for streaming requests. call_details: Describes an RPC to be invoked. Returns: The type of the return should match the type of the return value received by calling `method`. This is an object that is both a `Call `_ for the RPC and a `Future `_. The actual result from the RPC can be got by calling `.result()` on the value returned from `method`. """ new_details = ClientCallDetails( call_details.method, call_details.timeout, [("authorization", "Bearer mysecrettoken")], call_details.credentials, call_details.wait_for_ready, call_details.compression, ) return method(request_or_iterator, new_details) Now inject your interceptor when you create the ``grpc`` channel: .. code-block:: python interceptors = [MetadataClientInterceptor()] with grpc.insecure_channel("grpc-server:50051") as channel: channel = grpc.intercept_channel(channel, *interceptors) ... Client interceptors can also be used to `retry RPCs `_ that fail due to specific errors, or a host of other use cases. There are some basic approaches in `the tests `_ to get you started. Note: The ``method`` in a client interceptor is a ``continuation`` as described in the `client interceptor section of the gRPC docs `_. When you invoke the continuation, you get a future back, which resolves to either the result, or exception. This is different than invoking a client stub, which returns the result directly. If the interceptor needs the value returned by the call, or to catch exceptions, then you'll need to do ``future = method(request_or_iterator, call_details)``, followed by ``future.result()``. 
Check out the tests for `examples `_. Testing ------- The testing framework provides an actual gRPC service and client, which you can inject interceptors into. This allows end-to-end testing, rather than mocking things out (such as the context). This can catch interactions between your interceptors and the gRPC framework, and also allows chaining interceptors. The crux of the testing framework is the ``dummy_client`` context manager. It provides a client to a gRPC service, which by defaults echos the ``input`` field of the request to the ``output`` field of the response. You can also provide a ``special_cases`` dict which tells the service to call arbitrary functions when the input matches a key in the dict. This allows you to test things like exceptions being thrown. Here's an example (again using ``ExceptionToStatusInterceptor``): .. code-block:: python from grpc_interceptor import ExceptionToStatusInterceptor from grpc_interceptor.exceptions import NotFound from grpc_interceptor.testing import dummy_client, DummyRequest, raises def test_exception(): special_cases = {"error": raises(NotFound())} interceptors = [ExceptionToStatusInterceptor()] with dummy_client(special_cases=special_cases, interceptors=interceptors) as client: # Test a happy path first assert client.Execute(DummyRequest(input="foo")).output == "foo" # And now a special case with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="error")) assert e.value.code() == grpc.StatusCode.NOT_FOUND Limitations ----------- Known limitations: * Async client interceptors are not implemented. * The ``read`` / ``write`` API for async streaming technically works, but you'll need to roll your own solution to get access to streaming request and response objects. Contributions or requests are welcome for any limitations you may find. ================================================ FILE: docs/license.rst ================================================ License ======= .. 
include:: ../LICENSE ================================================ FILE: docs/reference.rst ================================================ Reference ========= .. contents:: :local: :backlinks: none grpc_interceptor --------------------- .. automodule:: grpc_interceptor :members: grpc_interceptor.exceptions ------------------------------------ .. automodule:: grpc_interceptor.exceptions :members: grpc_interceptor.testing ------------------------------------ .. automodule:: grpc_interceptor.testing :members: ================================================ FILE: docs/requirements.txt ================================================ sphinx==3.1.2 docutils==0.16 ================================================ FILE: mypy.ini ================================================ [mypy] [mypy-nox.*,grpc,pytest] ignore_missing_imports = True ================================================ FILE: noxfile.py ================================================ """Nox sessions.""" from contextlib import contextmanager from pathlib import Path import tempfile from typing import List from uuid import uuid4 import nox import toml nox.options.sessions = "lint", "mypy", "tests", "xdoctest", "mindeps" PY_VERSIONS = ["3.11", "3.10", "3.9", "3.8", "3.7"] PY_LATEST = "3.11" @nox.session(python=PY_VERSIONS) def tests(session): """Run the test suite.""" args = session.posargs or [ "--cov", "--cov-report", "term-missing", "--cov-report", "xml", "-ra", "-vv", ] session.run("poetry", "install", "--no-dev", external=True) install_with_constraints( session, "coverage[toml]", "grpcio-tools", "pytest", "pytest-asyncio", "pytest-cov", ) session.run("pytest", *args) @nox.session(python=PY_VERSIONS) def xdoctest(session) -> None: """Run examples with xdoctest.""" args = session.posargs or ["all"] session.run("poetry", "install", "--no-dev", external=True) install_with_constraints(session, "xdoctest") session.run("python", "-m", "xdoctest", "grpc_interceptor", *args) @nox.session(python=PY_LATEST) 
def docs(session): """Build the documentation.""" session.run("poetry", "install", "--no-dev", "-E", "testing", external=True) install_with_constraints(session, "sphinx") session.run("sphinx-build", "docs", "docs/_build") SOURCE_CODE = ["src", "tests", "noxfile.py", "docs/conf.py"] @nox.session(python=PY_LATEST) def black(session): """Run black code formatter.""" args = session.posargs or SOURCE_CODE install_with_constraints(session, "black") session.run("black", *args) @nox.session(python=PY_LATEST) def lint(session): """Lint using flake8.""" args = session.posargs or SOURCE_CODE install_with_constraints( session, "flake8", "flake8-bandit", "flake8-bugbear", "flake8-docstrings", "flake8-import-order", ) session.run("flake8", *args) @nox.session(python=PY_VERSIONS) def mypy(session): """Type-check using mypy.""" args = session.posargs or SOURCE_CODE install_with_constraints(session, "mypy") session.run("mypy", "--install-types", "--non-interactive", *args) session.run("mypy", *args) @nox.session(python=PY_LATEST) def safety(session): """Scan dependencies for insecure packages.""" with _temp_file() as requirements: session.run( "poetry", "export", "--dev", "--format=requirements.txt", "--without-hashes", f"--output={requirements}", external=True, ) install_with_constraints(session, "safety") session.run("safety", "check", f"--file={requirements}", "--full-report") @nox.session(python="3.7") def mindeps(session): """Run test with minimum versions of dependencies.""" deps = _parse_minimum_dependency_versions() session.install(*deps) session.run("pytest", env={"PYTHONPATH": "src"}) def install_with_constraints(session, *args, **kwargs): """Install packages constrained by Poetry's lock file.""" with _temp_file() as requirements: session.run( "poetry", "export", "--dev", "--format=requirements.txt", f"--output={requirements}", "--without-hashes", external=True, ) session.install(f"--constraint={requirements}", *args, **kwargs) @contextmanager def _temp_file(): # 
NamedTemporaryFile doesn't work on Windows. path = Path(tempfile.gettempdir()) / str(uuid4()) try: yield path finally: try: path.unlink() except FileNotFoundError: pass def _parse_minimum_dependency_versions() -> List[str]: pyproj = toml.load("pyproject.toml") dependencies = pyproj["tool"]["poetry"]["dependencies"] dev_dependencies = pyproj["tool"]["poetry"]["dev-dependencies"] min_deps = [] for deps in (dependencies, dev_dependencies): for dep, constraint in deps.items(): if dep == "python": continue if not isinstance(constraint, str): # Don't install deps with python contraints, because they're always for # newer versions on python. if "python" in constraint: continue constraint = constraint["version"] if constraint.startswith("^") or constraint.startswith("~"): version = constraint[1:] elif constraint.startswith(">="): version = constraint[2:] else: version = constraint min_deps.append(f"{dep}=={version}") return min_deps ================================================ FILE: pyproject.toml ================================================ [tool.poetry] name = "grpc-interceptor" version = "0.15.4" description = "Simplifies gRPC interceptors" license = "MIT" readme = "README.md" homepage = "https://github.com/d5h-foss/grpc-interceptor" repository = "https://github.com/d5h-foss/grpc-interceptor" keywords = ["grpc", "interceptor"] authors = ["Dan Hipschman"] documentation = "https://grpc-interceptor.readthedocs.io" [tool.poetry.dependencies] python = "^3.7" # Earliest version that supports aio w/o deadlock issue: https://github.com/grpc/grpc/pull/23945 grpcio = "^1.49.1" # https://github.com/protocolbuffers/protobuf/issues/10075 protobuf = {version = ">=4.21.9", optional = true} [tool.poetry.extras] testing = ["protobuf"] [tool.poetry.dev-dependencies] pytest = "^6.1.0" grpcio-tools = "^1.49.1" coverage = {extras = ["toml"], version = "^7.2.3"} pytest-cov = "^2.10.0" black = "^23.3.0" flake8 = "^5.0.0" flake8-bandit = "^4.1.1" # 
https://github.com/tylerwince/flake8-bandit/issues/21 flake8-bugbear = "^20.1.4" flake8-import-order = "^0.18.1" safety = "^1.9.0" mypy = "^1.2.0" mypy-protobuf = "^1.23" flake8-docstrings = "^1.5.0" sphinx = "^3.1.2" xdoctest = "^0.13.0" pytest-asyncio = {version = "^0.19.0", python = ">=3.7"} [tool.coverage.paths] source = ["src"] [tool.coverage.run] branch = true source = ["grpc_interceptor"] omit = ["*_pb2.py", "*_pb2_grpc.py"] [tool.coverage.report] show_missing = true [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" ================================================ FILE: src/grpc_interceptor/__init__.py ================================================ """Simplified Python gRPC interceptors.""" from grpc_interceptor.client import ClientCallDetails, ClientInterceptor from grpc_interceptor.exception_to_status import ( AsyncExceptionToStatusInterceptor, ExceptionToStatusInterceptor, ) from grpc_interceptor.server import ( AsyncServerInterceptor, MethodName, parse_method_name, ServerInterceptor, ) __all__ = [ "AsyncExceptionToStatusInterceptor", "AsyncServerInterceptor", "ClientCallDetails", "ClientInterceptor", "ExceptionToStatusInterceptor", "MethodName", "parse_method_name", "ServerInterceptor", ] ================================================ FILE: src/grpc_interceptor/client.py ================================================ """Base class for client-side interceptors.""" import abc from typing import Any, Callable, Iterator, NamedTuple, Optional, Sequence, Tuple, Union import grpc class _ClientCallDetailsFields(NamedTuple): method: str timeout: Optional[float] metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] credentials: Optional[grpc.CallCredentials] wait_for_ready: Optional[bool] compression: Any # Type added in grpcio 1.23.0 class ClientCallDetails(_ClientCallDetailsFields, grpc.ClientCallDetails): """Describes an RPC to be invoked. 
See https://grpc.github.io/grpc/python/grpc.html#grpc.ClientCallDetails """ pass class ClientInterceptorReturnType(grpc.Call, grpc.Future): """Return type for the ClientInterceptor.intercept method.""" pass class ClientInterceptor( grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor, metaclass=abc.ABCMeta, ): """Base class for client-side interceptors. To implement an interceptor, subclass this class and override the intercept method. """ @abc.abstractmethod def intercept( self, method: Callable, request_or_iterator: Any, call_details: grpc.ClientCallDetails, ) -> ClientInterceptorReturnType: """Override this method to implement a custom interceptor. This method is called for all unary and streaming RPCs. The interceptor implementation should call `method` using a `grpc.ClientCallDetails` and the `request_or_iterator` object as parameters. The `request_or_iterator` parameter may be type checked to determine if this is a singluar request for unary RPCs or an iterator for client-streaming or client-server streaming RPCs. Args: method: A function that proceeds with the invocation by executing the next interceptor in the chain or invoking the actual RPC on the underlying channel. request_or_iterator: RPC request message or iterator of request messages for streaming requests. call_details: Describes an RPC to be invoked. Returns: The type of the return should match the type of the return value received by calling `method`. This is an object that is both a `Call `_ for the RPC and a `Future `_. The actual result from the RPC can be got by calling `.result()` on the value returned from `method`. """ return method(request_or_iterator, call_details) # pragma: no cover def intercept_unary_unary( self, continuation: Callable, call_details: grpc.ClientCallDetails, request: Any, ): """Implementation of grpc.UnaryUnaryClientInterceptor. 
This is not part of the grpc_interceptor.ClientInterceptor API, but must have a public name. Do not override it, unless you know what you're doing. """ return self.intercept(_swap_args(continuation), request, call_details) def intercept_unary_stream( self, continuation: Callable, call_details: grpc.ClientCallDetails, request: Any, ): """Implementation of grpc.UnaryStreamClientInterceptor. This is not part of the grpc_interceptor.ClientInterceptor API, but must have a public name. Do not override it, unless you know what you're doing. """ return self.intercept(_swap_args(continuation), request, call_details) def intercept_stream_unary( self, continuation: Callable, call_details: grpc.ClientCallDetails, request_iterator: Iterator[Any], ): """Implementation of grpc.StreamUnaryClientInterceptor. This is not part of the grpc_interceptor.ClientInterceptor API, but must have a public name. Do not override it, unless you know what you're doing. """ return self.intercept(_swap_args(continuation), request_iterator, call_details) def intercept_stream_stream( self, continuation: Callable, call_details: grpc.ClientCallDetails, request_iterator: Iterator[Any], ): """Implementation of grpc.StreamStreamClientInterceptor. This is not part of the grpc_interceptor.ClientInterceptor API, but must have a public name. Do not override it, unless you know what you're doing. 
""" return self.intercept(_swap_args(continuation), request_iterator, call_details) def _swap_args(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]: def new_fn(x, y): return fn(y, x) return new_fn ================================================ FILE: src/grpc_interceptor/exception_to_status.py ================================================ """ExceptionToStatusInterceptor catches GrpcException and sets the gRPC context.""" # TODO: use asynccontextmanager from contextlib import contextmanager from typing import ( Any, AsyncGenerator, AsyncIterable, Callable, Generator, Iterable, Iterator, NoReturn, Optional, ) import grpc from grpc import aio as grpc_aio from grpc_interceptor.exceptions import GrpcException from grpc_interceptor.server import AsyncServerInterceptor, ServerInterceptor class ExceptionToStatusInterceptor(ServerInterceptor): """An interceptor that catches exceptions and sets the RPC status and details. ExceptionToStatusInterceptor will catch any subclass of GrpcException and set the status code and details on the gRPC context. You can also extend this and override the handle_exception method to catch other types of exceptions, and handle them in different ways. E.g., you can catch and handle exceptions that don't derive from GrpcException. Or you can set rich error statuses with context.abort_with_status(). Args: status_on_unknown_exception: Specify what to do if an exception which is not a subclass of GrpcException is raised. If None, do nothing (by default, grpc will set the status to UNKNOWN). If not None, then the status code will be set to this value if `context.abort` hasn't been called earlier. It must not be OK. The details will be set to the value of repr(e), where e is the exception. In any case, the exception will be propagated. Raises: ValueError: If status_code is OK. 
""" def __init__(self, status_on_unknown_exception: Optional[grpc.StatusCode] = None): if status_on_unknown_exception == grpc.StatusCode.OK: raise ValueError("The status code for unknown exceptions cannot be OK") self._status_on_unknown_exception = status_on_unknown_exception def _generate_responses( self, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str, response_iterator: Iterable, ) -> Generator[Any, None, None]: """Yield all the responses, but check for errors along the way.""" with self._handle_exception(request_or_iterator, context, method_name): yield from response_iterator @contextmanager def _handle_exception( self, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str ) -> Iterator[None]: try: yield except Exception as ex: self.handle_exception(ex, request_or_iterator, context, method_name) def handle_exception( self, ex: Exception, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str, ) -> NoReturn: """Override this if extending ExceptionToStatusInterceptor. This will get called when an exception is raised while handling the RPC. Args: ex: The exception that was raised. request_or_iterator: The RPC request, as a protobuf message if it is a unary request, or an iterator of protobuf messages if it is a streaming request. context: The servicer context. You probably want to call context.abort(...) method_name: The name of the RPC being called. Raises: This method must raise and cannot return, as in general there's no meaningful RPC response to return if an exception has occurred. You can raise the original exception, ex, or something else. 
""" if isinstance(ex, GrpcException): context.abort(ex.status_code, ex.details) elif not context.code(): if self._status_on_unknown_exception is not None: context.abort(self._status_on_unknown_exception, repr(ex)) raise ex def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str, ) -> Any: """Do not call this directly; use the interceptor kwarg on grpc.server().""" with self._handle_exception(request_or_iterator, context, method_name): response_or_iterator = method(request_or_iterator, context) if isinstance(response_or_iterator, Iterable): # multiple responses; return a generator return self._generate_responses( request_or_iterator, context, method_name, response_or_iterator ) else: # return a single response return response_or_iterator class AsyncExceptionToStatusInterceptor(AsyncServerInterceptor): """An interceptor that catches exceptions and sets the RPC status and details. This is the async analogy to ExceptionToStatusInterceptor. Please see that class' documentation for more information. """ def __init__(self, status_on_unknown_exception: Optional[grpc.StatusCode] = None): if status_on_unknown_exception == grpc.StatusCode.OK: raise ValueError("The status code for unknown exceptions cannot be OK") self._status_on_unknown_exception = status_on_unknown_exception async def _generate_responses( self, request_or_iterator: Any, context: grpc_aio.ServicerContext, method_name: str, response_iterator: AsyncIterable, ) -> AsyncGenerator[Any, None]: """Yield all the responses, but check for errors along the way.""" try: async for r in response_iterator: yield r except Exception as ex: await self.handle_exception(ex, request_or_iterator, context, method_name) async def handle_exception( self, ex: Exception, request_or_iterator: Any, context: grpc_aio.ServicerContext, method_name: str, ) -> NoReturn: """Override this if extending ExceptionToStatusInterceptor. 
This will get called when an exception is raised while handling the RPC. Args: ex: The exception that was raised. request_or_iterator: The RPC request, as a protobuf message if it is a unary request, or an iterator of protobuf messages if it is a streaming request. context: The servicer context. You probably want to call context.abort(...) method_name: The name of the RPC being called. Raises: This method must raise and cannot return, as in general there's no meaningful RPC response to return if an exception has occurred. You can raise the original exception, ex, or something else. """ if isinstance(ex, GrpcException): await context.abort(ex.status_code, ex.details) elif not context.code(): if self._status_on_unknown_exception is not None: await context.abort(self._status_on_unknown_exception, repr(ex)) raise ex async def intercept( self, method: Callable, request_or_iterator: Any, context: grpc_aio.ServicerContext, method_name: str, ) -> Any: """Do not call this directly; use the interceptor kwarg on grpc.server().""" try: response_or_iterator = method(request_or_iterator, context) if not hasattr(response_or_iterator, "__aiter__"): return await response_or_iterator except Exception as ex: await self.handle_exception(ex, request_or_iterator, context, method_name) return self._generate_responses( request_or_iterator, context, method_name, response_or_iterator ) ================================================ FILE: src/grpc_interceptor/exceptions.py ================================================ """Exceptions for ExceptionToStatusInterceptor. See https://grpc.github.io/grpc/core/md_doc_statuscodes.html for the source of truth on status code meanings. """ from typing import Optional from grpc import StatusCode class GrpcException(Exception): """Base class for gRPC exceptions. 
Generally you would not use this class directly, but rather use a subclass representing one of the standard gRPC status codes (see: https://grpc.github.io/grpc/core/md_doc_statuscodes.html for the official list). Attributes: status_code: A grpc.StatusCode other than OK. The only use case for this is if gRPC adds a new status code that isn't represented by one of the subclasses of GrpcException. Must not be OK, because gRPC will not raise an RpcError to the client if the status code is OK. details: A string with additional informantion about the error. Args: details: If not None, specifies a custom error message. status_code: If not None, sets the status code. Raises: ValueError: If status_code is OK. """ status_code: StatusCode = StatusCode.UNKNOWN details: str = "Unknown exception occurred" def __init__( self, details: Optional[str] = None, status_code: Optional[StatusCode] = None ): if status_code is not None: if status_code == StatusCode.OK: raise ValueError("The status code for an exception cannot be OK") self.status_code = status_code if details is not None: self.details = details def __repr__(self) -> str: """Show the status code and details. Returns: A string displaying the class name, status code, and details. """ clsname = self.__class__.__name__ sc = self.status_code.name return f"{clsname}(status_code={sc}, details={self.details!r})" @property def status_string(self): """Return status_code as a string. Returns: The status code as a string. Example: >>> GrpcException(status_code=StatusCode.NOT_FOUND).status_string 'NOT_FOUND' """ return self.status_code.name class Aborted(GrpcException): """The operation was aborted. Typically this is due to a concurrency issue such as a sequencer check failure or transaction abort. See the guidelines on other exceptions for deciding between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE. 
""" status_code = StatusCode.ABORTED details = "The operation was aborted" class AlreadyExists(GrpcException): """The entity that a client attempted to create already exists. E.g., a file or directory that a client is trying to create already exists. """ status_code = StatusCode.ALREADY_EXISTS details = "The entity attempted to be created already exists" class Cancelled(GrpcException): """The operation was cancelled, typically by the caller.""" status_code = StatusCode.CANCELLED details = "The operation was cancelled" class DataLoss(GrpcException): """Unrecoverable data loss or corruption.""" status_code = StatusCode.DATA_LOSS details = "There was unrecoverable data loss or corruption" class DeadlineExceeded(GrpcException): """The deadline expired before the operation could complete. For operations that change the state of the system, this error may be returned even if the operation has completed successfully. For example, a successful response from a server could have been delayed long. """ status_code = StatusCode.DEADLINE_EXCEEDED details = "Deadline expired before operation could complete" class FailedPrecondition(GrpcException): """The operation failed because the system is in an invalid state for execution. For example, the directory to be deleted is non-empty, an rmdir operation is applied to a non-directory, etc. Service implementors can use the following guidelines to decide between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE: (a) Use UNAVAILABLE if the client can retry just the failing call. (b) Use ABORTED if the client should retry at a higher level (e.g., when a client-specified test-and-set fails, indicating the client should restart a read-modify-write sequence). (c) Use FAILED_PRECONDITION if the client should not retry until the system state has been explicitly fixed. 
E.g., if an "rmdir" fails because the directory is non-empty, FAILED_PRECONDITION should be returned since the client should not retry unless the files are deleted from the directory. """ status_code = StatusCode.FAILED_PRECONDITION details = ( "The operation was rejected because the system is not" " in a state required for execution" ) class InvalidArgument(GrpcException): """The client specified an invalid argument. Note that this differs from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments that are problematic regardless of the state of the system (e.g., a malformed file name). """ status_code = StatusCode.INVALID_ARGUMENT details = "The client specified an invalid argument" class Internal(GrpcException): """Internal errors. This means that some invariants expected by the underlying system have been broken. This error code is reserved for serious errors. """ status_code = StatusCode.INTERNAL details = "Internal error" class OutOfRange(GrpcException): """The operation was attempted past the valid range. E.g., seeking or reading past end-of-file. Unlike INVALID_ARGUMENT, this error indicates a problem that may be fixed if the system state changes. For example, a 32-bit file system will generate INVALID_ARGUMENT if asked to read at an offset that is not in the range [0,2^32-1], but it will generate OUT_OF_RANGE if asked to read from an offset past the current file size. There is a fair bit of overlap between FAILED_PRECONDITION and OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific error) when it applies so that callers who are iterating through a space can easily look for an OUT_OF_RANGE error to detect when they are done. """ status_code = StatusCode.OUT_OF_RANGE details = "The operation was attempted past the valid range" class NotFound(GrpcException): """Some requested entity (e.g., file or directory) was not found. 
Note to server developers: if a request is denied for an entire class of users, such as gradual feature rollout or undocumented whitelist, NOT_FOUND may be used. If a request is denied for some users within a class of users, such as user-based access control, PERMISSION_DENIED must be used. """ status_code = StatusCode.NOT_FOUND details = "The requested entity was not found" class PermissionDenied(GrpcException): """The caller does not have permission to execute the specified operation. PERMISSION_DENIED must not be used for rejections caused by exhausting some resource (use RESOURCE_EXHAUSTED instead for those errors). PERMISSION_DENIED must not be used if the caller can not be identified (use UNAUTHENTICATED instead for those errors). This error code does not imply the request is valid or the requested entity exists or satisfies other pre-conditions. """ status_code = StatusCode.PERMISSION_DENIED details = "The caller does not have permission to execute the specified operation" class ResourceExhausted(GrpcException): """Some resource has been exhausted. Perhaps a per-user quota, or perhaps the entire file system is out of space. """ status_code = StatusCode.RESOURCE_EXHAUSTED details = "A resource has been exhausted" class Unauthenticated(GrpcException): """The request does not have valid authentication credentials for the operation.""" status_code = StatusCode.UNAUTHENTICATED details = ( "The request does not have valid authentication credentials for the operation" ) class Unavailable(GrpcException): """The service is currently unavailable. This is most likely a transient condition, which can be corrected by retrying with a backoff. Note that it is not always safe to retry non-idempotent operations. 
""" status_code = StatusCode.UNAVAILABLE details = "The service is currently unavailable" class Unimplemented(GrpcException): """The operation is not implemented or is not supported/enabled in this service.""" status_code = StatusCode.UNIMPLEMENTED details = ( "The operation is not implemented or not supported/enabled in this service" ) class Unknown(GrpcException): """Unknown error. For example, this error may be returned when a Status value received from another address space belongs to an error space that is not known in this address space. Also errors raised by APIs that do not return enough error information may be converted to this error. """ pass ================================================ FILE: src/grpc_interceptor/py.typed ================================================ ================================================ FILE: src/grpc_interceptor/server.py ================================================ """Base class for server-side interceptors.""" import abc from asyncio import iscoroutine from typing import Any, Callable, Tuple import grpc from grpc import aio as grpc_aio # Needed for grpcio pre-1.33.2 class ServerInterceptor(grpc.ServerInterceptor, metaclass=abc.ABCMeta): """Base class for server-side interceptors. To implement an interceptor, subclass this class and override the intercept method. """ @abc.abstractmethod def intercept( self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str, ) -> Any: # pragma: no cover """Override this method to implement a custom interceptor. You should call method(request_or_iterator, context) to invoke the next handler (either the RPC method implementation, or the next interceptor in the list). Args: method: Either the RPC method implementation, or the next interceptor in the chain. request_or_iterator: The RPC request, as a protobuf message if it is a unary request, or an iterator of protobuf messages if it is a streaming request. 
context: The ServicerContext pass by gRPC to the service. method_name: A string of the form "/protobuf.package.Service/Method" Returns: This should return the result of method(request, context), which is typically the RPC method response, as a protobuf message, or an iterator of protobuf messages for streaming responses. The interceptor is free to modify this in some way, however. """ return method(request_or_iterator, context) # Implementation of grpc.ServerInterceptor, do not override. def intercept_service(self, continuation, handler_call_details): """Implementation of grpc.ServerInterceptor. This is not part of the grpc_interceptor.ServerInterceptor API, but must have a public name. Do not override it, unless you know what you're doing. """ next_handler = continuation(handler_call_details) # Returns None if the method isn't implemented. if next_handler is None: return handler_factory, next_handler_method = _get_factory_and_method(next_handler) def invoke_intercept_method(request_or_iterator, context): method_name = handler_call_details.method return self.intercept( next_handler_method, request_or_iterator, context, method_name, ) return handler_factory( invoke_intercept_method, request_deserializer=next_handler.request_deserializer, response_serializer=next_handler.response_serializer, ) class AsyncServerInterceptor(grpc_aio.ServerInterceptor, metaclass=abc.ABCMeta): """Base class for asyncio server-side interceptors. To implement an interceptor, subclass this class and override the intercept method. """ @abc.abstractmethod async def intercept( self, method: Callable, request_or_iterator: Any, context: grpc_aio.ServicerContext, method_name: str, ) -> Any: # pragma: no cover """Override this method to implement a custom interceptor. You should await method(request_or_iterator, context) to invoke the next handler (either the RPC method implementation, or the next interceptor in the list). 
Args: method: Either the RPC method implementation, or the next interceptor in the chain. request_or_iterator: The RPC request, as a protobuf message if it is a unary request, or an iterator of protobuf messages if it is a streaming request. context: The ServicerContext pass by gRPC to the service. method_name: A string of the form "/protobuf.package.Service/Method" Returns: This should return the result of method(request_or_iterator, context), which is typically the RPC method response, as a protobuf message. The interceptor is free to modify this in some way, however. """ response_or_iterator = method(request_or_iterator, context) if hasattr(response_or_iterator, "__aiter__"): return response_or_iterator else: return await response_or_iterator # Implementation of grpc.ServerInterceptor, do not override. async def intercept_service(self, continuation, handler_call_details): """Implementation of grpc.aio.ServerInterceptor. This is not part of the grpc_interceptor.AsyncServerInterceptor API, but must have a public name. Do not override it, unless you know what you're doing. """ next_handler = await continuation(handler_call_details) # Returns None if the method isn't implemented. if not next_handler: return handler_factory, next_handler_method = _get_factory_and_method(next_handler) if next_handler.response_streaming: async def invoke_intercept_method(request, context): method_name = handler_call_details.method coroutine_or_asyncgen = self.intercept( next_handler_method, request, context, method_name, ) # Async server streaming handlers return async_generator, because they # use the async def + yield syntax. However, this is NOT a coroutine # and hence is not awaitable. This can be a problem if the interceptor # ignores the individual streaming response items and simply returns the # result of method(request, context). In that case the interceptor IS a # coroutine, and hence should be awaited. 
In both cases, we need # something we can iterate over so that THIS function is an # async_generator like the actual RPC method. if iscoroutine(coroutine_or_asyncgen): asyncgen_or_none = await coroutine_or_asyncgen # If a handler is using the read/write API, it will return None. if not asyncgen_or_none: return asyncgen = asyncgen_or_none else: asyncgen = coroutine_or_asyncgen async for r in asyncgen: yield r else: async def invoke_intercept_method(request, context): method_name = handler_call_details.method return await self.intercept( next_handler_method, request, context, method_name, ) return handler_factory( invoke_intercept_method, request_deserializer=next_handler.request_deserializer, response_serializer=next_handler.response_serializer, ) def _get_factory_and_method( rpc_handler: grpc.RpcMethodHandler, ) -> Tuple[Callable, Callable]: if rpc_handler.unary_unary: return grpc.unary_unary_rpc_method_handler, rpc_handler.unary_unary elif rpc_handler.unary_stream: return grpc.unary_stream_rpc_method_handler, rpc_handler.unary_stream elif rpc_handler.stream_unary: return grpc.stream_unary_rpc_method_handler, rpc_handler.stream_unary elif rpc_handler.stream_stream: return grpc.stream_stream_rpc_method_handler, rpc_handler.stream_stream else: # pragma: no cover raise RuntimeError("RPC handler implementation does not exist") class MethodName: """Represents a gRPC method name. gRPC methods are defined by three parts, represented by the three attributes. Attributes: package: This is defined by the `package foo.bar;` designation in the protocol buffer definition, or it could be defined by the protocol buffer directory structure, depending on the language (see https://developers.google.com/protocol-buffers/docs/proto3#packages). service: This is the service name in the protocol buffer definition (e.g., `service SearchService { ... }`. method: This is the method name. (e.g., `rpc Search(...) returns (...);`). 
""" def __init__(self, package: str, service: str, method: str): self.package = package self.service = service self.method = method def __repr__(self) -> str: """Object-like representation.""" return ( f"MethodName(package='{self.package}', service='{self.service}'," f" method='{self.method}')" ) @property def fully_qualified_service(self): """Return the service name prefixed with the package. Example: >>> MethodName("foo.bar", "SearchService", "Search").fully_qualified_service 'foo.bar.SearchService' """ return f"{self.package}.{self.service}" if self.package else self.service def parse_method_name(method_name: str) -> MethodName: """Parse a method name into package, service and endpoint components. Arguments: method_name: A string of the form "/foo.bar.SearchService/Search", as passed to ServerInterceptor.intercept(). Returns: A MethodName object. Example: >>> parse_method_name("/foo.bar.SearchService/Search") MethodName(package='foo.bar', service='SearchService', method='Search') """ _, package_and_service, method = method_name.split("/") *maybe_package, service = package_and_service.rsplit(".", maxsplit=1) package = maybe_package[0] if maybe_package else "" return MethodName(package, service, method) ================================================ FILE: src/grpc_interceptor/testing/__init__.py ================================================ """A framework for testing interceptors.""" from typing import Callable from grpc_interceptor.testing.dummy_client import ( dummy_client, DummyService, ) from grpc_interceptor.testing.protos.dummy_pb2 import DummyRequest, DummyResponse __all__ = [ "dummy_client", "DummyRequest", "DummyResponse", "DummyService", "raises", ] def raises(e: Exception) -> Callable: """Return a function that raises the given exception when called. Args: e: The exception to be raised. Returns: A function that can take any arguments, and raises the given exception. 
""" def f(*args, **kwargs): raise (e) return f ================================================ FILE: src/grpc_interceptor/testing/dummy_client.py ================================================ """Defines a service and client for testing interceptors.""" import asyncio from concurrent import futures from contextlib import contextmanager from inspect import iscoroutine from threading import Event, Thread from typing import ( Any, AsyncGenerator, AsyncIterable, Callable, Dict, Iterable, List, Optional, Union ) import grpc from grpc_interceptor.client import ClientInterceptor from grpc_interceptor.server import AsyncServerInterceptor, grpc_aio, ServerInterceptor from grpc_interceptor.testing.protos import dummy_pb2_grpc from grpc_interceptor.testing.protos.dummy_pb2 import DummyRequest, DummyResponse SpecialCaseFunction = Callable[ [str, Union[grpc.ServicerContext, grpc_aio.ServicerContext]], str ] class _SpecialCaseMixin: _special_cases: Dict[str, SpecialCaseFunction] def _get_output(self, request: DummyRequest, context: grpc.ServicerContext) -> str: input = request.input output = input if input in self._special_cases: output = self._special_cases[input](input, context) return output async def _get_output_async( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> str: input = request.input output = input if input in self._special_cases: output = self._special_cases[input](input, context) if iscoroutine(output): output = await output return output class DummyService(dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin): """A gRPC service used for testing. Args: special_cases: A dictionary where the keys are strings, and the values are functions that take and return strings. The functions can also raise exceptions. When the Execute method is given a string in the dict, it will call the function with that string instead, and return the result. This allows testing special cases, like raising exceptions. 
""" def __init__( self, special_cases: Dict[str, SpecialCaseFunction], ): self._special_cases = special_cases def Execute( self, request: DummyRequest, context: grpc.ServicerContext ) -> DummyResponse: """Echo the input, or take on of the special cases actions.""" return DummyResponse(output=self._get_output(request, context)) def ExecuteClientStream( self, request_iter: Iterable[DummyRequest], context: grpc.ServicerContext ) -> DummyResponse: """Iterate over the input and concatenates the strings into the output.""" output = "".join(self._get_output(request, context) for request in request_iter) return DummyResponse(output=output) def ExecuteServerStream( self, request: DummyRequest, context: grpc.ServicerContext ) -> Iterable[DummyResponse]: """Stream one character at a time from the input.""" for c in self._get_output(request, context): yield DummyResponse(output=c) def ExecuteClientServerStream( self, request_iter: Iterable[DummyRequest], context: grpc.ServicerContext ) -> Iterable[DummyResponse]: """Stream input to output.""" for request in request_iter: yield DummyResponse(output=self._get_output(request, context)) class AsyncDummyService(dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin): """A gRPC service used for testing, similar to DummyService except async. See DummyService for more info. 
""" def __init__( self, special_cases: Dict[str, SpecialCaseFunction], ): self._special_cases = special_cases async def Execute( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> DummyResponse: """Echo the input, or take on of the special cases actions.""" return DummyResponse(output=await self._get_output_async(request, context)) async def ExecuteClientStream( self, request_iter: AsyncIterable[DummyRequest], context: grpc_aio.ServicerContext, ) -> DummyResponse: """Iterate over the input and concatenates the strings into the output.""" output = "".join([ await self._get_output_async(request, context) async for request in request_iter ]) # noqa: E501 return DummyResponse(output=output) async def ExecuteServerStream( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> AsyncGenerator[DummyResponse, None]: """Stream one character at a time from the input.""" for c in await self._get_output_async(request, context): yield DummyResponse(output=c) async def ExecuteClientServerStream( self, request_iter: AsyncIterable[DummyRequest], context: grpc_aio.ServicerContext, ) -> AsyncGenerator[DummyResponse, None]: """Stream input to output.""" async for request in request_iter: yield DummyResponse(output=await self._get_output_async(request, context)) class AsyncReadWriteDummyService( dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin ): """Similar to AsyncDummyService except uses the read / write API. See DummyService for more info. 
""" def __init__( self, special_cases: Dict[str, SpecialCaseFunction], ): self._special_cases = special_cases async def Execute( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> DummyResponse: """Echo the input, or take on of the special cases actions.""" return DummyResponse(output=await self._get_output_async(request, context)) async def ExecuteClientStream( self, unused_request: Any, context: grpc_aio.ServicerContext, ) -> DummyResponse: """Iterate over the input and concatenates the strings into the output.""" output = [] while True: request = await context.read() if request == grpc_aio.EOF: break output.append(await self._get_output_async(request, context)) return DummyResponse(output="".join(output)) async def ExecuteServerStream( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> None: """Stream one character at a time from the input.""" for c in await self._get_output_async(request, context): await context.write(DummyResponse(output=c)) async def ExecuteClientServerStream( self, request_iter: AsyncIterable[DummyRequest], context: grpc_aio.ServicerContext, ) -> None: """Stream input to output.""" while True: request = await context.read() if request == grpc_aio.EOF: break await context.write( DummyResponse(output=await self._get_output_async(request, context)) ) @contextmanager def dummy_client( special_cases: Dict[str, SpecialCaseFunction], interceptors: Optional[List[ServerInterceptor]] = None, client_interceptors: Optional[List[ClientInterceptor]] = None, aio_server: bool = False, aio_client: bool = False, aio_read_write: bool = False, ): """A context manager that returns a gRPC client connected to a DummyService.""" # Sanity check that the interceptors are async if using an async server, # otherwise the tests will just hang. 
for intr in interceptors or []: if aio_server != isinstance(intr, AsyncServerInterceptor): raise TypeError("Set aio_server correctly") with dummy_channel( special_cases, interceptors, client_interceptors, aio_server=aio_server, aio_client=aio_client, aio_read_write=aio_read_write, ) as channel: client = dummy_pb2_grpc.DummyServiceStub(channel) yield client @contextmanager def dummy_channel( special_cases: Dict[str, SpecialCaseFunction], interceptors: Optional[List[ServerInterceptor]] = None, client_interceptors: Optional[List[ClientInterceptor]] = None, aio_server: bool = False, aio_client: bool = False, aio_read_write: bool = False, ): """A context manager that returns a gRPC channel connected to a DummyService.""" if not interceptors: interceptors = [] if aio_server: service = ( AsyncReadWriteDummyService(special_cases) if aio_read_write else AsyncDummyService(special_cases) ) aio_loop = asyncio.new_event_loop() aio_thread = _AsyncServerThread( aio_loop, service, interceptors, ) aio_thread.start() aio_thread.wait_for_server() port = aio_thread.port else: dummy_service = DummyService(special_cases) server = grpc.server( futures.ThreadPoolExecutor(max_workers=1), interceptors=interceptors ) dummy_pb2_grpc.add_DummyServiceServicer_to_server(dummy_service, server) port = server.add_insecure_port("localhost:0") server.start() channel_descriptor = f"localhost:{port}" if aio_client: channel = grpc_aio.insecure_channel(channel_descriptor) # Client interceptors might work, but I haven't tested them yet. if client_interceptors: raise TypeError("Client interceptors not supported with async channel") # We close the channel in _AsyncServerThread because we need to await # it, and doing that in this thread is problematic because dummy_client # isn't always used in an async context. We could get around that by # creating a new loop or something, but will be lazy and use the server # thread / loop for now. 
if not aio_server: raise ValueError("aio_server must be True if aio_client is True") aio_thread.async_channel = channel else: channel = grpc.insecure_channel(channel_descriptor) if client_interceptors: channel = grpc.intercept_channel(channel, *client_interceptors) try: yield channel finally: if not aio_client: # async channel is closed by _AsyncServerThread channel.close() if aio_server: aio_thread.stop() aio_thread.join() else: server.stop(None) class _AsyncServerThread(Thread): port: int = 0 async_channel = None def __init__( self, loop: asyncio.AbstractEventLoop, service, interceptors: List[AsyncServerInterceptor], ): super().__init__() self.__loop = loop self.__service = service self.__interceptors = interceptors self.__started = Event() def run(self): asyncio.set_event_loop(self.__loop) self.__loop.run_until_complete(self.__run_server()) async def __run_server(self): self.__server = grpc_aio.server(interceptors=tuple(self.__interceptors)) dummy_pb2_grpc.add_DummyServiceServicer_to_server(self.__service, self.__server) self.port = self.__server.add_insecure_port("localhost:0") await self.__server.start() self.__started.set() await self.__server.wait_for_termination() if self.async_channel: await self.async_channel.close() def wait_for_server(self): self.__started.wait() def stop(self): self.__loop.call_soon_threadsafe( lambda: asyncio.ensure_future(self.__shutdown()) ) async def __shutdown(self) -> None: await self.__server.stop(None) self.__loop.stop() ================================================ FILE: src/grpc_interceptor/testing/protos/__init__.py ================================================ """Protobuf definitions for testing.""" ================================================ FILE: src/grpc_interceptor/testing/protos/dummy.proto ================================================ syntax = "proto3"; message DummyRequest { string input = 1; } message DummyResponse { string output = 1; } service DummyService { rpc Execute (DummyRequest) returns 
(DummyResponse); rpc ExecuteClientStream (stream DummyRequest) returns (DummyResponse); rpc ExecuteServerStream (DummyRequest) returns (stream DummyResponse); rpc ExecuteClientServerStream (stream DummyRequest) returns (stream DummyResponse); } ================================================ FILE: src/grpc_interceptor/testing/protos/dummy_pb2.py ================================================ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: dummy.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( b'\n\x0b\x64ummy.proto"\x1d\n\x0c\x44ummyRequest\x12\r\n\x05input\x18\x01 \x01(\t"\x1f\n\rDummyResponse\x12\x0e\n\x06output\x18\x01 \x01(\t2\xe8\x01\n\x0c\x44ummyService\x12(\n\x07\x45xecute\x12\r.DummyRequest\x1a\x0e.DummyResponse\x12\x36\n\x13\x45xecuteClientStream\x12\r.DummyRequest\x1a\x0e.DummyResponse(\x01\x12\x36\n\x13\x45xecuteServerStream\x12\r.DummyRequest\x1a\x0e.DummyResponse0\x01\x12>\n\x19\x45xecuteClientServerStream\x12\r.DummyRequest\x1a\x0e.DummyResponse(\x01\x30\x01\x62\x06proto3' ) _DUMMYREQUEST = DESCRIPTOR.message_types_by_name["DummyRequest"] _DUMMYRESPONSE = DESCRIPTOR.message_types_by_name["DummyResponse"] DummyRequest = _reflection.GeneratedProtocolMessageType( "DummyRequest", (_message.Message,), { "DESCRIPTOR": _DUMMYREQUEST, "__module__": "dummy_pb2" # @@protoc_insertion_point(class_scope:DummyRequest) }, ) _sym_db.RegisterMessage(DummyRequest) DummyResponse = _reflection.GeneratedProtocolMessageType( "DummyResponse", (_message.Message,), { "DESCRIPTOR": _DUMMYRESPONSE, "__module__": "dummy_pb2" # 
@@protoc_insertion_point(class_scope:DummyResponse) }, ) _sym_db.RegisterMessage(DummyResponse) _DUMMYSERVICE = DESCRIPTOR.services_by_name["DummyService"] if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None _DUMMYREQUEST._serialized_start = 15 _DUMMYREQUEST._serialized_end = 44 _DUMMYRESPONSE._serialized_start = 46 _DUMMYRESPONSE._serialized_end = 77 _DUMMYSERVICE._serialized_start = 80 _DUMMYSERVICE._serialized_end = 312 # @@protoc_insertion_point(module_scope) ================================================ FILE: src/grpc_interceptor/testing/protos/dummy_pb2.pyi ================================================ # @generated by generate_proto_mypy_stubs.py. Do not edit! import sys from google.protobuf.descriptor import ( Descriptor as google___protobuf___descriptor___Descriptor, FileDescriptor as google___protobuf___descriptor___FileDescriptor, ) from google.protobuf.message import Message as google___protobuf___message___Message from typing import ( Optional as typing___Optional, Text as typing___Text, ) from typing_extensions import Literal as typing_extensions___Literal builtin___bool = bool builtin___bytes = bytes builtin___float = float builtin___int = int DESCRIPTOR: google___protobuf___descriptor___FileDescriptor = ... class DummyRequest(google___protobuf___message___Message): DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... input: typing___Text = ... def __init__( self, *, input: typing___Optional[typing___Text] = None, ) -> None: ... def ClearField( self, field_name: typing_extensions___Literal["input", b"input"] ) -> None: ... type___DummyRequest = DummyRequest class DummyResponse(google___protobuf___message___Message): DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... output: typing___Text = ... def __init__( self, *, output: typing___Optional[typing___Text] = None, ) -> None: ... def ClearField( self, field_name: typing_extensions___Literal["output", b"output"] ) -> None: ... 
type___DummyResponse = DummyResponse ================================================ FILE: src/grpc_interceptor/testing/protos/dummy_pb2_grpc.py ================================================ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc from grpc_interceptor.testing.protos import ( dummy_pb2 as grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2, ) class DummyServiceStub(object): """Missing associated documentation comment in .proto file.""" def __init__(self, channel): """Constructor. Args: channel: A grpc.Channel. """ self.Execute = channel.unary_unary( "/DummyService/Execute", request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, ) self.ExecuteClientStream = channel.stream_unary( "/DummyService/ExecuteClientStream", request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, ) self.ExecuteServerStream = channel.unary_stream( "/DummyService/ExecuteServerStream", request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, ) self.ExecuteClientServerStream = channel.stream_stream( "/DummyService/ExecuteClientServerStream", request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, ) class DummyServiceServicer(object): """Missing associated documentation comment in .proto file.""" def Execute(self, request, context): """Missing associated 
documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def ExecuteClientStream(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def ExecuteServerStream(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def ExecuteClientServerStream(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def add_DummyServiceServicer_to_server(servicer, server): rpc_method_handlers = { "Execute": grpc.unary_unary_rpc_method_handler( servicer.Execute, request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString, response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString, ), "ExecuteClientStream": grpc.stream_unary_rpc_method_handler( servicer.ExecuteClientStream, request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString, response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString, ), "ExecuteServerStream": grpc.unary_stream_rpc_method_handler( servicer.ExecuteServerStream, request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString, response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString, ), "ExecuteClientServerStream": 
grpc.stream_stream_rpc_method_handler( servicer.ExecuteClientServerStream, request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString, response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( "DummyService", rpc_method_handlers ) server.add_generic_rpc_handlers((generic_handler,)) # This class is part of an EXPERIMENTAL API. class DummyService(object): """Missing associated documentation comment in .proto file.""" @staticmethod def Execute( request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None, ): return grpc.experimental.unary_unary( request, target, "/DummyService/Execute", grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, ) @staticmethod def ExecuteClientStream( request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None, ): return grpc.experimental.stream_unary( request_iterator, target, "/DummyService/ExecuteClientStream", grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, ) @staticmethod def ExecuteServerStream( request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None, ): return grpc.experimental.unary_stream( request, target, 
"/DummyService/ExecuteServerStream", grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, ) @staticmethod def ExecuteClientServerStream( request_iterator, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None, ): return grpc.experimental.stream_stream( request_iterator, target, "/DummyService/ExecuteClientServerStream", grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString, grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata, ) ================================================ FILE: tests/__init__.py ================================================ """Test suite for grpc-interceptor.""" ================================================ FILE: tests/test_client.py ================================================ """Test cases for the grpc-interceptor base ClientInterceptor.""" from collections import defaultdict import itertools from typing import List, Tuple import grpc import pytest from grpc_interceptor import ClientInterceptor from grpc_interceptor.testing import dummy_client, DummyRequest, raises class MetadataInterceptor(ClientInterceptor): """A test interceptor that injects invocation metadata.""" def __init__(self, metadata: List[Tuple[str, str]]): self._metadata = metadata def intercept(self, method, request_or_iterator, call_details): """Add invocation metadata to request.""" new_details = call_details._replace(metadata=self._metadata) return method(request_or_iterator, new_details) class CodeCountInterceptor(ClientInterceptor): """Test interceptor that counts status codes 
returned by the server.""" def __init__(self): self.counts = defaultdict(int) def intercept(self, method, request_or_iterator, call_details): """Call continuation and count status codes.""" future = method(request_or_iterator, call_details) self.counts[future.code()] += 1 return future class RetryInterceptor(ClientInterceptor): """Test interceptor that retries failed RPCs.""" def __init__(self, retries): self._retries = retries def intercept(self, method, request_or_iterator, call_details): """Call the continuation and retry up to retries times if it fails.""" tries_remaining = 1 + self._retries while 0 < tries_remaining: future = method(request_or_iterator, call_details) try: future.result() return future except Exception: tries_remaining -= 1 return future class CrashingService: """Special case function that raises a given number of times before succeeding.""" DEFAULT_EXCEPTION = ValueError("oops") def __init__(self, num_crashes, success_value="OK", exception=DEFAULT_EXCEPTION): self._num_crashes = num_crashes self._success_value = success_value self._exception = exception def __call__(self, *args, **kwargs): """Raise the first num_crashes times called, then return success_value.""" if 0 < self._num_crashes: self._num_crashes -= 1 raise self._exception return self._success_value class CachingInterceptor(ClientInterceptor): """A test interceptor that caches responses based on input string.""" def __init__(self): self._cache = {} def intercept(self, method, request_or_iterator, call_details): """Cache responses based on input string.""" if hasattr(request_or_iterator, "__iter__"): request_or_iterator, copy_iterator = itertools.tee(request_or_iterator) cache_key = tuple(r.input for r in copy_iterator) else: cache_key = request_or_iterator.input if cache_key not in self._cache: self._cache[cache_key] = method(request_or_iterator, call_details) return self._cache[cache_key] @pytest.fixture def metadata_string(): """Expected joined metadata string.""" return 
"this_key:this_value" @pytest.fixture def metadata_client(): """Client with metadata interceptor.""" intr = MetadataInterceptor([("this_key", "this_value")]) interceptors = [intr] special_cases = { "metadata": lambda _, c: ",".join( f"{key}:{value}" for key, value in c.invocation_metadata() ) } with dummy_client( special_cases=special_cases, client_interceptors=interceptors ) as client: yield client def test_metadata_unary(metadata_client, metadata_string): """Invocation metadata should be added to the servicer context.""" unary_output = metadata_client.Execute(DummyRequest(input="metadata")).output assert metadata_string in unary_output def test_metadata_server_stream(metadata_client, metadata_string): """Invocation metadata should be added to the servicer context.""" server_stream_output = [ r.output for r in metadata_client.ExecuteServerStream(DummyRequest(input="metadata")) ] assert metadata_string in "".join(server_stream_output) def test_metadata_client_stream(metadata_client, metadata_string): """Invocation metadata should be added to the servicer context.""" client_stream_input = iter((DummyRequest(input="metadata"),)) client_stream_output = metadata_client.ExecuteClientStream( client_stream_input ).output assert metadata_string in client_stream_output def test_metadata_client_server_stream(metadata_client, metadata_string): """Invocation metadata should be added to the servicer context.""" stream_stream_input = iter((DummyRequest(input="metadata"),)) result = metadata_client.ExecuteClientServerStream(stream_stream_input) stream_stream_output = [r.output for r in result] assert metadata_string in "".join(stream_stream_output) def test_code_counting(): """Access to code on call details works correctly.""" interceptor = CodeCountInterceptor() special_cases = {"error": raises(ValueError("oops"))} with dummy_client( special_cases=special_cases, client_interceptors=[interceptor] ) as client: assert interceptor.counts == {} 
client.Execute(DummyRequest(input="foo")) assert interceptor.counts == {grpc.StatusCode.OK: 1} with pytest.raises(grpc.RpcError): client.Execute(DummyRequest(input="error")) assert interceptor.counts == {grpc.StatusCode.OK: 1, grpc.StatusCode.UNKNOWN: 1} def test_basic_retry(): """Calling the continuation multiple times should work.""" interceptor = RetryInterceptor(retries=1) special_cases = {"error_once": CrashingService(num_crashes=1)} with dummy_client( special_cases=special_cases, client_interceptors=[interceptor] ) as client: assert client.Execute(DummyRequest(input="error_once")).output == "OK" def test_failed_retry(): """The interceptor can return failed futures.""" interceptor = RetryInterceptor(retries=1) special_cases = {"error_twice": CrashingService(num_crashes=2)} with dummy_client( special_cases=special_cases, client_interceptors=[interceptor] ) as client: with pytest.raises(grpc.RpcError): client.Execute(DummyRequest(input="error_twice")) def test_chaining(): """Chaining interceptors should work.""" retry_interceptor = RetryInterceptor(retries=1) code_count_interceptor = CodeCountInterceptor() interceptors = [retry_interceptor, code_count_interceptor] special_cases = {"error_once": CrashingService(num_crashes=1)} with dummy_client( special_cases=special_cases, client_interceptors=interceptors ) as client: assert code_count_interceptor.counts == {} assert client.Execute(DummyRequest(input="error_once")).output == "OK" assert code_count_interceptor.counts == { grpc.StatusCode.OK: 1, grpc.StatusCode.UNKNOWN: 1, } def test_caching(): """Caching calls (not calling the continuation) should work.""" caching_interceptor = CachingInterceptor() # Use this to test how many times the continuation is called. 
code_count_interceptor = CodeCountInterceptor() interceptors = [caching_interceptor, code_count_interceptor] with dummy_client(special_cases={}, client_interceptors=interceptors) as client: assert code_count_interceptor.counts == {} assert client.Execute(DummyRequest(input="hello")).output == "hello" assert code_count_interceptor.counts == {grpc.StatusCode.OK: 1} assert client.Execute(DummyRequest(input="hello")).output == "hello" assert code_count_interceptor.counts == {grpc.StatusCode.OK: 1} assert client.Execute(DummyRequest(input="goodbye")).output == "goodbye" assert code_count_interceptor.counts == {grpc.StatusCode.OK: 2} # Try streaming requests inputs = ["foo", "bar"] input_iter = (DummyRequest(input=input) for input in inputs) assert client.ExecuteClientStream(input_iter).output == "foobar" assert code_count_interceptor.counts == {grpc.StatusCode.OK: 3} input_iter = (DummyRequest(input=input) for input in inputs) assert client.ExecuteClientStream(input_iter).output == "foobar" assert code_count_interceptor.counts == {grpc.StatusCode.OK: 3} ================================================ FILE: tests/test_exception_to_status.py ================================================ """Test cases for ExceptionToStatusInterceptor.""" import re from typing import Any, List, Optional, Union import grpc from grpc import aio as grpc_aio import pytest from grpc_interceptor import exceptions as gx from grpc_interceptor.exception_to_status import ( AsyncExceptionToStatusInterceptor, ExceptionToStatusInterceptor, ) from grpc_interceptor.testing import dummy_client, DummyRequest, raises class NonGrpcException(Exception): """An exception that does not derive from GrpcException.""" TEST_STATUS_CODE = grpc.StatusCode.DATA_LOSS TEST_DETAILS = "Here's some custom details" class ExtendedExceptionToStatusInterceptor(ExceptionToStatusInterceptor): """A test case for extending ExceptionToStatusInterceptor.""" def __init__(self): self.caught_custom_exception = False def 
handle_exception(self, ex, request_or_iterator, context, method_name): """Handles NonGrpcException in a special way.""" if isinstance(ex, NonGrpcException): self.caught_custom_exception = True context.abort( NonGrpcException.TEST_STATUS_CODE, NonGrpcException.TEST_DETAILS ) else: super().handle_exception(ex, request_or_iterator, context, method_name) class AsyncExtendedExceptionToStatusInterceptor(AsyncExceptionToStatusInterceptor): """A test case for extending AsyncExceptionToStatusInterceptor.""" def __init__(self): self.caught_custom_exception = False async def handle_exception(self, ex, request_or_iterator, context, method_name): """Handles NonGrpcException in a special way.""" if isinstance(ex, NonGrpcException): self.caught_custom_exception = True await context.abort( NonGrpcException.TEST_STATUS_CODE, NonGrpcException.TEST_DETAILS ) else: await super().handle_exception( ex, request_or_iterator, context, method_name ) def _get_interceptors( aio: bool, status_on_unknown_exception: Optional[grpc.StatusCode] = None ) -> List[Union[ExceptionToStatusInterceptor, AsyncExceptionToStatusInterceptor]]: return ( [ AsyncExceptionToStatusInterceptor( status_on_unknown_exception=status_on_unknown_exception ) ] if aio else [ ExceptionToStatusInterceptor( status_on_unknown_exception=status_on_unknown_exception ) ] ) def test_repr(): """repr() should display the class name, status code, and details.""" assert ( repr(gx.GrpcException(details="oops")) == "GrpcException(status_code=UNKNOWN, details='oops')" ) assert ( repr(gx.GrpcException(status_code=grpc.StatusCode.NOT_FOUND, details="oops")) == "GrpcException(status_code=NOT_FOUND, details='oops')" ) assert ( repr(gx.NotFound(details="?")) == "NotFound(status_code=NOT_FOUND, details='?')" ) def test_status_string(): """status_string should be the string version of the status code.""" assert gx.GrpcException().status_string == "UNKNOWN" assert ( gx.GrpcException(status_code=grpc.StatusCode.NOT_FOUND).status_string == 
"NOT_FOUND" ) assert gx.NotFound().status_string == "NOT_FOUND" @pytest.mark.parametrize("aio", [False, True]) def test_no_exception(aio): """An RPC with no exceptions should work as if the interceptor wasn't there.""" interceptors = _get_interceptors(aio) with dummy_client( special_cases={}, interceptors=interceptors, aio_server=aio ) as client: assert client.Execute(DummyRequest(input="foo")).output == "foo" @pytest.mark.parametrize("aio", [False, True]) def test_custom_details(aio): """We can set custom details.""" interceptors = _get_interceptors(aio) special_cases = {"error": raises(gx.NotFound(details="custom"))} with dummy_client( special_cases=special_cases, interceptors=interceptors, aio_server=aio ) as client: assert ( client.Execute(DummyRequest(input="foo")).output == "foo" ) # Test a happy path too with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="error")) assert e.value.code() == grpc.StatusCode.NOT_FOUND assert e.value.details() == "custom" @pytest.mark.parametrize("aio", [False, True]) def test_non_grpc_exception(aio): """Exceptions other than GrpcExceptions are ignored.""" interceptors = _get_interceptors(aio) special_cases = {"error": raises(ValueError("oops"))} with dummy_client( special_cases=special_cases, interceptors=interceptors, aio_server=aio ) as client: with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="error")) assert e.value.code() == grpc.StatusCode.UNKNOWN @pytest.mark.parametrize("aio", [False, True]) def test_non_grpc_exception_with_override(aio): """We can set a custom status code when non-GrpcExceptions are raised.""" interceptors = _get_interceptors( aio, status_on_unknown_exception=grpc.StatusCode.INTERNAL ) special_cases = {"error": raises(ValueError("oops"))} with dummy_client( special_cases=special_cases, interceptors=interceptors, aio_server=aio ) as client: with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="error")) assert e.value.code() == 
grpc.StatusCode.INTERNAL assert re.fullmatch(r"ValueError\('oops',?\)", e.value.details()) @pytest.mark.parametrize("aio", [False, True]) def test_aborted_context(aio): """If the context is aborted, the exception is propagated.""" def error(request: Any, context: grpc.ServicerContext) -> None: context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted') async def async_error(request: Any, context: grpc_aio.ServicerContext) -> None: await context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted') interceptors = _get_interceptors(aio, grpc.StatusCode.INTERNAL) special_cases = { "error": async_error if aio else error } with dummy_client( special_cases=special_cases, interceptors=interceptors, aio_server=aio ) as client: with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="error")) assert e.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED assert re.fullmatch(r"resource exhausted", e.value.details()) def test_override_with_ok(): """We cannot set the default status code to OK.""" with pytest.raises(ValueError): ExceptionToStatusInterceptor(status_on_unknown_exception=grpc.StatusCode.OK) with pytest.raises(ValueError): AsyncExceptionToStatusInterceptor( status_on_unknown_exception=grpc.StatusCode.OK ) @pytest.mark.parametrize("aio", [False, True]) def test_all_exceptions(aio): """Every gRPC status code is represented, and they each are unique. Make sure we aren't missing any status codes, and that we didn't copy paste the same status code or details into two different classes. 
""" interceptors = _get_interceptors(aio) all_status_codes = {sc for sc in grpc.StatusCode if sc != grpc.StatusCode.OK} seen_codes = set() seen_details = set() for sc in all_status_codes: ex = getattr(gx, _snake_to_camel(sc.name)) assert ex special_cases = {"error": raises(ex())} with dummy_client( special_cases=special_cases, interceptors=interceptors, aio_server=aio ) as client: with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="error")) assert e.value.code() == sc assert e.value.details() == ex.details seen_codes.add(sc) seen_details.add(ex.details) assert seen_codes == all_status_codes assert len(seen_details) == len(all_status_codes) @pytest.mark.parametrize("aio", [False, True]) def test_exception_in_streaming_response(aio): """Exceptions are raised correctly from streaming responses.""" interceptors = _get_interceptors(aio) with dummy_client( special_cases={"error": raises(gx.NotFound("not found!"))}, interceptors=interceptors, aio_server=aio, ) as client: with pytest.raises(grpc.RpcError) as e: list(client.ExecuteServerStream(DummyRequest(input="error"))) assert e.value.code() == grpc.StatusCode.NOT_FOUND assert e.value.details() == "not found!" 
def _snake_to_camel(s: str) -> str: """Convert a SNAKE_CASE name such as NOT_FOUND to CamelCase such as NotFound.""" return "".join([p.title() for p in s.split("_")]) def test_not_ok(): """We cannot create a GrpcException with an OK status code.""" with pytest.raises(ValueError): gx.GrpcException(status_code=grpc.StatusCode.OK) @pytest.mark.parametrize("aio", [False, True]) def test_extending(aio): """We can extend (Async)ExceptionToStatusInterceptor to handle custom exceptions.""" interceptor = ( AsyncExtendedExceptionToStatusInterceptor() if aio else ExtendedExceptionToStatusInterceptor() ) special_cases = {"error": raises(NonGrpcException())} with dummy_client( special_cases=special_cases, interceptors=[interceptor], aio_server=aio ) as client: assert ( client.Execute(DummyRequest(input="foo")).output == "foo" ) # Test a happy path too assert not interceptor.caught_custom_exception with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="error")) assert e.value.code() == NonGrpcException.TEST_STATUS_CODE assert e.value.details() == NonGrpcException.TEST_DETAILS assert interceptor.caught_custom_exception ================================================ FILE: tests/test_server.py ================================================ """Test cases for the grpc-interceptor base ServerInterceptor.""" from collections import defaultdict import grpc import pytest from grpc_interceptor import ( AsyncServerInterceptor, MethodName, parse_method_name, ServerInterceptor, ) from grpc_interceptor.testing import dummy_client, DummyRequest from grpc_interceptor.testing.dummy_client import dummy_channel class CountingInterceptor(ServerInterceptor): """A test interceptor that counts calls and exceptions.""" def __init__(self): self.num_calls = defaultdict(int) self.num_errors = defaultdict(int) def intercept(self, method, request, context, method_name): """Count each call and exception.""" self.num_calls[method_name] += 1 try: return method(request, context) except Exception: self.num_errors[method_name] += 1 raise class AsyncCountingInterceptor(AsyncServerInterceptor): """A test
interceptor that counts calls and exceptions.""" def __init__(self): self.num_calls = defaultdict(int) self.num_errors = defaultdict(int) async def intercept(self, method, request, context, method_name): """Count each call and exception.""" self.num_calls[method_name] += 1 try: return await method(request, context) except Exception: self.num_errors[method_name] += 1 raise class SideEffectInterceptor(ServerInterceptor): """A test interceptor that calls a function for the side effect.""" def __init__(self, side_effect): self._side_effect = side_effect def intercept(self, method, request, context, method_name): """Call the side effect and then the RPC method.""" self._side_effect() return method(request, context) class AsyncSideEffectInterceptor(AsyncServerInterceptor): """A test interceptor that calls a function for the side effect.""" def __init__(self, side_effect): self._side_effect = side_effect async def intercept(self, method, request, context, method_name): """Call the side effect and then the RPC method.""" self._side_effect() return await method(request, context) class UppercasingInterceptor(ServerInterceptor): """A test interceptor that modifies the request by uppercasing the input field.""" def intercept(self, method, request, context, method_name): """Uppercases request.input.""" request.input = request.input.upper() return method(request, context) class AsyncUppercasingInterceptor(AsyncServerInterceptor): """A test interceptor that modifies the request by uppercasing the input field.""" async def intercept(self, method, request, context, method_name): """Uppercases request.input.""" request.input = request.input.upper() return await method(request, context) class AbortingInterceptor(ServerInterceptor): """A test interceptor that aborts before calling the handler.""" def __init__(self, message): self._message = message def intercept(self, method, request, context, method_name): """Abort with ABORTED and the stored message without calling the handler.""" context.abort(grpc.StatusCode.ABORTED, self._message) class
AsyncAbortingInterceptor(AsyncServerInterceptor): """A test interceptor that aborts before calling the handler.""" def __init__(self, message): self._message = message async def intercept(self, method, request, context, method_name): """Abort with ABORTED and the stored message without calling the handler.""" await context.abort(grpc.StatusCode.ABORTED, self._message) @pytest.mark.parametrize("aio", [False, True]) def test_call_counts(aio): """Calls and errors are counted per full method name.""" intr_type = AsyncCountingInterceptor if aio else CountingInterceptor intr = intr_type() interceptors = [intr] special_cases = {"error": lambda r, c: 1 / 0} with dummy_client( special_cases=special_cases, interceptors=interceptors, aio_server=aio ) as client: assert client.Execute(DummyRequest(input="foo")).output == "foo" assert len(intr.num_calls) == 1 assert intr.num_calls["/DummyService/Execute"] == 1 assert len(intr.num_errors) == 0 with pytest.raises(grpc.RpcError): client.Execute(DummyRequest(input="error")) assert len(intr.num_calls) == 1 assert intr.num_calls["/DummyService/Execute"] == 2 assert len(intr.num_errors) == 1 assert intr.num_errors["/DummyService/Execute"] == 1 @pytest.mark.parametrize("aio", [False, True]) def test_interceptor_chain(aio): """Interceptors are called in the order they appear in the interceptors list.""" trace = [] intr_type = AsyncSideEffectInterceptor if aio else SideEffectInterceptor interceptor1 = intr_type(lambda: trace.append(1)) interceptor2 = intr_type(lambda: trace.append(2)) with dummy_client( special_cases={}, interceptors=[interceptor1, interceptor2], aio_server=aio ) as client: assert client.Execute(DummyRequest(input="test")).output == "test" assert trace == [1, 2] @pytest.mark.parametrize("aio", [False, True]) def test_modifying_interceptor(aio): """Interceptors can modify requests.""" intr_type = AsyncUppercasingInterceptor if aio else UppercasingInterceptor interceptor = intr_type() with dummy_client( special_cases={}, interceptors=[interceptor], aio_server=aio ) as client: assert client.Execute(DummyRequest(input="test")).output
== "TEST" @pytest.mark.parametrize("aio", [False, True]) def test_aborting_interceptor(aio): """context.abort called in an interceptor reaches the client with its code and details.""" intr_type = AsyncAbortingInterceptor if aio else AbortingInterceptor interceptor = intr_type("oh no") with dummy_client( special_cases={}, interceptors=[interceptor], aio_server=aio ) as client: with pytest.raises(grpc.RpcError) as e: client.Execute(DummyRequest(input="test")) assert e.value.code() == grpc.StatusCode.ABORTED assert e.value.details() == "oh no" @pytest.mark.parametrize("aio", [False, True]) def test_method_not_found(aio): """Calling undefined endpoints should return Unimplemented. Interceptors are not invoked when the RPC call is not handled. """ intr_type = AsyncCountingInterceptor if aio else CountingInterceptor intr = intr_type() interceptors = [intr] with dummy_channel( special_cases={}, interceptors=interceptors, aio_server=aio ) as channel: with pytest.raises(grpc.RpcError) as e: channel.unary_unary( "/DummyService/Unimplemented", )(b"") assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED assert len(intr.num_calls) == 0 assert len(intr.num_errors) == 0 def test_method_name(): """Fields are correct and fully_qualified_service works.""" mn = MethodName("foo.bar", "SearchService", "Search") assert mn.package == "foo.bar" assert mn.service == "SearchService" assert mn.method == "Search" assert mn.fully_qualified_service == "foo.bar.SearchService" def test_empty_package_method_name(): """fully_qualified_service works when there's no package.""" mn = MethodName("", "SearchService", "Search") assert mn.fully_qualified_service == "SearchService" def test_parse_method_name(): """parse_method_name parses fields when there's a package.""" mn =
parse_method_name("/SearchService/Search") assert mn.package == "" assert mn.service == "SearchService" assert mn.method == "Search" ================================================ FILE: tests/test_streaming.py ================================================ """Test cases for streaming RPCs.""" import sys import grpc import pytest from grpc_interceptor import AsyncServerInterceptor, ServerInterceptor from grpc_interceptor.testing import dummy_client, DummyRequest class StreamingInterceptor(ServerInterceptor): """A pass-through test interceptor used to exercise streaming RPCs.""" def intercept(self, method, request, context, method_name): """Doesn't do anything; just make sure we handle streaming RPCs.""" return method(request, context) class AsyncStreamingInterceptor(AsyncServerInterceptor): """A pass-through async test interceptor used to exercise streaming RPCs.""" async def intercept(self, method, request, context, method_name): """Pass through: return async iterators as-is, otherwise await the result.""" response_or_iterator = method(request, context) if hasattr(response_or_iterator, "__aiter__"): return response_or_iterator else: return await response_or_iterator class ServerStreamingLoggingInterceptor(ServerInterceptor): """A test interceptor that logs a stream of server responses.""" def __init__(self): self._logs = [] def intercept(self, method, request, context, method_name): """Log each response object and re-yield.""" for resp in method(request, context): self._logs.append(resp.output) yield resp class AsyncServerStreamingLoggingInterceptor(AsyncServerInterceptor): """A test interceptor that logs a stream of server responses.""" def __init__(self): self._logs = [] async def intercept(self, method, request, context, method_name): """Log each response object and re-yield.""" async for resp in method(request, context): self._logs.append(resp.output) yield resp class ServerOmniLoggingInterceptor(ServerInterceptor): """A test interceptor that logs both unary and streaming server responses.""" def __init__(self): self._logs = [] def
_log_and_yield(self, iterator): """Re-yield each response; once exhausted, log all outputs as one list entry.""" logs = [] for resp in iterator: logs.append(resp.output) yield resp self._logs.append(logs) def intercept(self, method, request, context, method_name): """Log a unary response directly, or wrap a response iterator to log each item.""" response_or_iterator = method(request, context) if hasattr(response_or_iterator, "__iter__"): return self._log_and_yield(response_or_iterator) else: self._logs.append(response_or_iterator.output) return response_or_iterator class AsyncServerOmniLoggingInterceptor(AsyncServerInterceptor): """A test interceptor that logs both unary and streaming server responses.""" def __init__(self): self._logs = [] async def _log_and_yield(self, iterator): """Re-yield each response; once exhausted, log all outputs as one list entry.""" logs = [] async for resp in iterator: logs.append(resp.output) yield resp self._logs.append(logs) async def intercept(self, method, request, context, method_name): """Await and log a unary response, or wrap an async iterator to log each item.""" response_or_iterator = method(request, context) if hasattr(response_or_iterator, "__aiter__"): return self._log_and_yield(response_or_iterator) else: response_or_iterator = await response_or_iterator self._logs.append(response_or_iterator.output) return response_or_iterator class ClientStreamingLoggingInterceptor(ServerInterceptor): """A test interceptor that logs a stream of client requests.""" def __init__(self): self._logs = [] def _log_and_yield(self, request): for r in request: self._logs.append(r.input) yield r def intercept(self, method, request, context, method_name): """Log each request object and pass through.""" req = self._log_and_yield(request) return method(req, context) class AsyncClientStreamingLoggingInterceptor(AsyncServerInterceptor): """A test interceptor that logs a stream of client requests.""" def __init__(self): self._logs = [] async def _log_and_yield(self, request): async for r in request: self._logs.append(r.input) yield r async def intercept(self, method, request, context, method_name): """Log each request object and pass through.""" req = self._log_and_yield(request)
return await method(req, context) @pytest.mark.parametrize("aio", [False, True]) @pytest.mark.parametrize("aio_rw", [False, True]) def test_client_streaming(aio, aio_rw): """Client streaming should work, and a failing stream should raise RpcError.""" intr = AsyncStreamingInterceptor() if aio else StreamingInterceptor() interceptors = [intr] special_cases = {"error": lambda r, c: 1 / 0} with dummy_client( special_cases=special_cases, interceptors=interceptors, aio_server=aio, aio_read_write=aio_rw, ) as client: inputs = ["foo", "bar"] input_iter = (DummyRequest(input=input) for input in inputs) assert client.ExecuteClientStream(input_iter).output == "foobar" inputs = ["foo", "error"] input_iter = (DummyRequest(input=input) for input in inputs) with pytest.raises(grpc.RpcError): client.ExecuteClientStream(input_iter) @pytest.mark.parametrize("aio", [False, True]) def test_server_streaming(aio): """Server streaming should work.""" intr = AsyncStreamingInterceptor() if aio else StreamingInterceptor() interceptors = [intr] with dummy_client( special_cases={}, interceptors=interceptors, aio_server=aio ) as client: output = [ r.output for r in client.ExecuteServerStream(DummyRequest(input="foo")) ] assert output == ["f", "o", "o"] @pytest.mark.parametrize("aio", [False, True]) def test_client_server_streaming(aio): """Bidirectional streaming should work.""" intr = AsyncStreamingInterceptor() if aio else StreamingInterceptor() interceptors = [intr] with dummy_client( special_cases={}, interceptors=interceptors, aio_server=aio ) as client: inputs = ["foo", "bar"] input_iter = (DummyRequest(input=input) for input in inputs) response = client.ExecuteClientServerStream(input_iter) assert [r.output for r in response] == inputs @pytest.mark.parametrize("aio", [False, True]) def test_interceptor_iterates_server_streaming(aio): """The interceptor should be able to iterate over streamed server responses.""" intr = ( AsyncServerStreamingLoggingInterceptor() if aio else ServerStreamingLoggingInterceptor() ) interceptors = [intr]
with dummy_client( special_cases={}, interceptors=interceptors, aio_server=aio ) as client: output = [ r.output for r in client.ExecuteServerStream(DummyRequest(input="foo")) ] assert output == ["f", "o", "o"] assert intr._logs == ["f", "o", "o"] @pytest.mark.parametrize("aio", [False, True]) def test_interceptor_handles_both_unary_and_streaming(aio): """One interceptor should handle and log both streamed and unary server responses.""" intr = ( AsyncServerOmniLoggingInterceptor() if aio else ServerOmniLoggingInterceptor() ) interceptors = [intr] with dummy_client( special_cases={}, interceptors=interceptors, aio_server=aio ) as client: output = [ r.output for r in client.ExecuteServerStream(DummyRequest(input="foo")) ] assert output == ["f", "o", "o"] assert intr._logs == [["f", "o", "o"]] r = client.Execute(DummyRequest(input="bar")) assert r.output == "bar" assert intr._logs == [["f", "o", "o"], "bar"] @pytest.mark.parametrize("aio", [False, True]) def test_client_log_streaming(aio): """Client streaming should work when the interceptor re-yields each request.""" intr = ( AsyncClientStreamingLoggingInterceptor() if aio else ClientStreamingLoggingInterceptor() ) interceptors = [intr] with dummy_client( special_cases={}, interceptors=interceptors, aio_server=aio ) as client: inputs = ["foo", "bar"] input_iter = (DummyRequest(input=input) for input in inputs) assert client.ExecuteClientStream(input_iter).output == "foobar" assert intr._logs == inputs @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python 3.7") @pytest.mark.asyncio async def test_client_streaming_write_method(): """Client streaming should work when using write().""" intr = AsyncClientStreamingLoggingInterceptor() interceptors = [intr] with dummy_client( special_cases={}, interceptors=interceptors, aio_server=True, aio_client=True ) as client: call = client.ExecuteClientStream() await call.write(DummyRequest(input="foo")) await call.write(DummyRequest(input="bar")) await call.done_writing() response = await call assert
response.output == "foobar" assert intr._logs == ["foo", "bar"]