[
  {
    "path": ".flake8",
    "content": "[flake8]\nselect = B,B9,C,D,E,F,I,S,W\nexclude = *_pb2.py,*_pb2_grpc.py\nignore = D107,W503\n\napplication-import-names = grpc_interceptor,tests\nimport-order-style = google\n\nmax-complexity = 10\nmax-line-length = 88\n\n# asserts in tests are OK\nper-file-ignores = tests/*:S101\n\ndocstring-convention = google\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n### What versions of the following are you using?\n\n* `python`:\n* `grpc-interceptor`:\n* `grpcio`:\n* `protobuf`:\n\n### What operating system (Linux, Windows,...) and version?\n\n\n### What did you do?\n\nPlease provide specific steps to reproduce the bug.\n\n### What did you expect to see?\n\n\n### What did you see instead?\n\nPlease include any information that's helpful for debugging (full error message, exception listing, stack trace, logs).\n\n### Anything else we should know about your project / environment?\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/something-else.md",
    "content": "---\nname: Something else\nabout: Feature requests and other things\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n\n"
  },
  {
    "path": ".github/release-drafter.yml",
    "content": "categories:\n    - title: ':boom: Breaking Changes'\n      label: 'breaking'\n    - title: ':package: Build System'\n      label: 'build'\n    - title: ':construction_worker: Continuous Integration'\n      label: 'ci'\n    - title: ':books: Documentation'\n      label: 'documentation'\n    - title: ':rocket: Features'\n      label: 'enhancement'\n    - title: ':beetle: Fixes'\n      label: 'bug'\n    - title: ':racehorse: Performance'\n      label: 'performance'\n    - title: ':hammer: Refactoring'\n      label: 'refactoring'\n    - title: ':fire: Removals and Deprecations'\n      label: 'removal'\n    - title: ':lipstick: Style'\n      label: 'style'\n    - title: ':rotating_light: Testing'\n      label: 'testing'\ntemplate: |\n    ## What’s Changed\n  \n    $CHANGES\n"
  },
  {
    "path": ".github/workflows/coverage.yml",
    "content": "name: Coverage\non: [push, pull_request]\njobs:\n    coverage:\n        runs-on: ubuntu-latest\n        steps:\n        - uses: actions/checkout@v2\n        - uses: actions/setup-python@v4\n          with:\n              python-version: '3.11'\n              architecture: x64\n        - run: pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9\n        - run: nox --sessions tests-3.11\n        - uses: codecov/codecov-action@v3\n          with:\n            token: ${{ secrets.CODECOV_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/mindeps.yml",
    "content": "name: Minimum Dependencies\non: [push, pull_request]\njobs:\n    mindeps:\n        runs-on: ubuntu-latest\n        container: python:3.7-slim\n        steps:\n        - name: Installing dependencies\n          run: |\n            pip install --upgrade pip &&\n            pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9\n        - uses: actions/checkout@v2\n        - run: |\n            cd \"$GITHUB_WORKSPACE\" &&\n            mkdir .nox &&\n            nox --sessions mindeps\n"
  },
  {
    "path": ".github/workflows/release-drafter.yml",
    "content": "name: Release Drafter\non:\n    push:\n        branches:\n        - master\njobs:\n    draft_release:\n        runs-on: ubuntu-latest\n        steps:\n        - uses: release-drafter/release-drafter@v5.6.1\n          env:\n              GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Release\non:\n    release:\n        types: [published]\njobs:\n    build:\n        runs-on: ubuntu-latest\n        steps:\n        - uses: actions/checkout@v2\n        - uses: actions/setup-python@v4\n          with:\n              python-version: '3.9'\n              architecture: x64\n        - run: pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9\n        - run: nox\n        - run: poetry build\n        - uses: actions/upload-artifact@v3\n          with:\n              name: dist\n              path: dist/\n    release:\n        needs: build\n        runs-on: ubuntu-latest\n        permissions:\n            id-token: write\n        environment:\n            name: pypi\n            url: https://pypi.org/p/grpc-interceptor\n        steps:\n        - uses: actions/download-artifact@v3\n          with:\n              name: dist\n              path: dist\n        - name: Publish package distributions to PyPI\n          uses: pypa/gh-action-pypi-publish@release/v1\n"
  },
  {
    "path": ".github/workflows/tests.yml",
    "content": "name: Tests\non: [push, pull_request]\njobs:\n    tests:\n        strategy:\n            fail-fast: false\n            matrix:\n                platform: [ubuntu-latest, macos-latest, windows-latest]\n                python-version: ['3.11', '3.10', '3.9', '3.8', '3.7']\n        name: Python ${{matrix.python-version}} ${{matrix.platform}}\n        runs-on: ${{matrix.platform}}\n        steps:\n        - uses: actions/checkout@v2\n        - uses: actions/setup-python@v4\n          with:\n              python-version: ${{matrix.python-version}}\n              architecture: x64\n        - run: pip install nox==2022.1.7 toml==0.10.2 poetry==1.0.9\n        - run: nox --python ${{matrix.python-version}}\n"
  },
  {
    "path": ".gitignore",
    "content": "/.coverage\n/.idea/\n/.nox/\n/.venv/\n/.vscode/\n/coverage.xml\n/dist/\n/docs/_build/\n/src/*.egg-info/\n__pycache__/\npoetry.toml\n"
  },
  {
    "path": ".readthedocs.yml",
    "content": "version: 2\nsphinx:\n    configuration: docs/conf.py\nformats: all\npython:\n    version: 3.8\n    install:\n    - requirements: docs/requirements.txt\n    - path: .\n      extra_requirements: [testing]\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\nAll notable changes to this project will be documented in this file.\n\nThe format is based roughly on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),\nand this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).\n\n## [0.13.1] - 2022-01-29\n### Added\n- Run tests on Python 3.9 and default to that for linting, etc.\n\n### Fixed\n- Raise UNIMPLEMENTED instead of UNKNOWN when no handler is found for a method (thanks Richard Mahlberg!)\n\n## [0.13.0] - 2020-12-27\n### Added\n- Client-side interceptors (thanks Michael Morgan!)\n\n### Changed (breaking)\n- Added a `context` parameter to special case functions in the `testing` module\n\n### Fixed\n- Build issue caused by [pip upgrade](https://github.com/cjolowicz/hypermodern-python/issues/174#issuecomment-745364836)\n- Docs not building correctly in nox\n\n## [0.12.0] - 2020-10-07\n### Added\n- Support for all streaming RPCs\n\n## [0.11.0] - 2020-07-24\n### Added\n- Expose some imports from the top-level\n\n### Changed (breaking)\n- Rename to `ServerInterceptor` (do not intend to make breaking name changes after this)\n\n## [0.10.0] - 2020-07-23\n### Added\n- `status_on_unknown_exception` to `ExceptionToStatusInterceptor`\n- `py.typed` (so `mypy` will type check the package)\n\n### Fixed\n- Allow protobuf version 4.0.x which is coming out soon and is backwards compatible\n- Testing in autodocs\n- Turn on xdoctest\n- Prevent autodoc from outputting default namedtuple docs\n\n### Changed (breaking)\n- Rename `Interceptor` to `ServiceInterceptor`\n\n## [0.9.0] - 2020-07-22\n### Added\n- The `testing` module\n- Some helper functions\n- Improved test coverage\n\n### Fixed\n- Protobuf compatibility improvements\n\n## [0.8.0] - 2020-07-19\n### Added\n- An `Interceptor` base class, to make it easy to define your own service interceptors\n- An `ExceptionToStatusInterceptor` interceptor\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Running Tests\n\nThis will run the unit tests quickly:\n\n```\npoetry install\nmake tests\n```\n\nIt doesn't run the entire test suite. See below for that.\n\n# Making a Pull Request\n\nPlease bump the version number in `pyproject.toml` when you make a pull request. This is needed to give the package a new version.\n\nAlso run lint checks and mypy before pushing. This runs in Github Actions as well, but you'll get faster feedback by running it locally. To do this, run `nox -s lint` and\n`nox -s mypy-3.x`, for whatever version of Python you have installed. For example, if\nyou're using Python 3.9, run `nox -s mypy-3.9`. If you need to make formatting changes,\nyou can run `nox -s black`. Note that `nox` isn't installed via `poetry`, due to\nthe way it works, so you'll need to install it globally. If you don't want to install\n`nox`, you can do all this in docker. For example, you can run this to mount the current\ndirectory into /app and work from there:\n\n```\ndocker run --rm -it --mount type=bind,src=\"$(pwd)\",dst=/app python:3.9 bash\n```\n\n# Adding Tests\n\nAdd both a sync and async version if applicable. You can follow the examples in many\ntests for this. Search for `aio` to find one. Assuming the test applies to both sync\nand async code, it will need to create a different test interceptor depending on the\nvalue of `aio`. Then just remember to pass `aio_server=aio` to `dummy_client`. The\nrest of the test can be the same for both sync and async. This is because the tests\ncreate a client which calls a server. The client doesn't care whether the server is\nsync or async.\n\n# On Changing Dependencies\n\nI want to keep this library very small, not just in terms of its own code, but in terms\nof the code it pulls in. Having many dependencies is a burden to users. It increases\ninstallation time (especially when solving constraints with newer pip or poetry). 
It\nincreases the likelihood of dependency conflicts, and generally just introduces more\nthat can go wrong. Hence, this library depends on as little as possible, and can\nhopefully stay that way.\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 Dan Hipschman\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Makefile",
    "content": "TEST_PROTOS := src/grpc_interceptor/testing/protos/dummy.proto\nTEST_PROTO_GEN := $(shell echo $(TEST_PROTOS) | sed 's/\\.proto/_pb2.py/g') \\\n\t\t\t\t  $(shell echo $(TEST_PROTOS) | sed 's/\\.proto/_pb2_grpc.py/g')\n\n$(TEST_PROTO_GEN): $(TEST_PROTOS)\n\tcd src && \\\n\tprintf \"%s\\n\" $(TEST_PROTOS) | \\\n\tsed 's|^src/||' | \\\n\txargs poetry run python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. --mypy_out=.\n\n.PHONY: test\ntest: $(TEST_PROTO_GEN)\n\tpoetry run pytest --cov\n\n.PHONY: nox-test\nnox-test: $(TEST_PROTO_GEN)\n\tnox -r\n"
  },
  {
    "path": "README.md",
    "content": "[![Tests](https://github.com/d5h-foss/grpc-interceptor/workflows/Tests/badge.svg)](https://github.com/d5h-foss/grpc-interceptor/actions?workflow=Tests)\n[![Codecov](https://codecov.io/gh/d5h-foss/grpc-interceptor/branch/master/graph/badge.svg)](https://codecov.io/gh/d5h-foss/grpc-interceptor)\n[![Read the Docs](https://readthedocs.org/projects/grpc-interceptor/badge/)](https://grpc-interceptor.readthedocs.io/)\n[![PyPI](https://img.shields.io/pypi/v/grpc-interceptor.svg)](https://pypi.org/project/grpc-interceptor/)\n\n# Summary\n\nSimplified Python gRPC interceptors.\n\nThe Python `grpc` package provides service interceptors, but they're a bit hard to\nuse because of their flexibility. The `grpc` interceptors don't have direct access\nto the request and response objects, or the service context. Access to these are often\ndesired, to be able to log data in the request or response, or set status codes on the\ncontext.\n\n# Installation\n\nTo just get the interceptors (and probably not write your own):\n\n```console\n$ pip install grpc-interceptor\n```\n\nTo also get the testing framework, which is good if you're writing your own interceptors:\n\n```console\n$ pip install grpc-interceptor[testing]\n```\n\n# Usage\n\n## Server Interceptor\n\nTo define your own interceptor (we can use `ExceptionToStatusInterceptor` as an example):\n\n```python\nfrom grpc_interceptor import ServerInterceptor\nfrom grpc_interceptor.exceptions import GrpcException\n\nclass ExceptionToStatusInterceptor(ServerInterceptor):\n    def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n    ) -> Any:\n        \"\"\"Override this method to implement a custom interceptor.\n         You should call method(request_or_iterator, context) to invoke the\n         next handler (either the RPC method implementation, or the\n         next interceptor in the list).\n         Args:\n        
     method: The next interceptor, or method implementation.\n             request_or_iterator: The RPC request, as a protobuf message.\n             context: The ServicerContext passed by gRPC to the service.\n             method_name: A string of the form\n                 \"/protobuf.package.Service/Method\"\n         Returns:\n             This should generally return the result of\n             method(request_or_iterator, context), which is typically the RPC\n             method response, as a protobuf message. The interceptor\n             is free to modify this in some way, however.\n         \"\"\"\n        try:\n            return method(request_or_iterator, context)\n        except GrpcException as e:\n            context.set_code(e.status_code)\n            context.set_details(e.details)\n            raise\n```\n\nThen inject your interceptor when you create the `grpc` server:\n\n```python\ninterceptors = [ExceptionToStatusInterceptor()]\nserver = grpc.server(\n    futures.ThreadPoolExecutor(max_workers=10),\n    interceptors=interceptors\n)\n```\n\nTo use `ExceptionToStatusInterceptor`:\n\n```python\nfrom grpc_interceptor.exceptions import NotFound\n\nclass MyService(my_pb2_grpc.MyServiceServicer):\n    def MyRpcMethod(\n        self, request: MyRequest, context: grpc.ServicerContext\n    ) -> MyResponse:\n        thing = lookup_thing()\n        if not thing:\n            raise NotFound(\"Sorry, your thing is missing\")\n        ...\n```\n\nThis results in the gRPC status code being set to `NOT_FOUND`,\nand the details `\"Sorry, your thing is missing\"`. This saves you the hassle of\ncatching exceptions in your service handler, or passing the context down into\nhelper functions so they can call `context.abort` or `context.set_code`. 
It allows\nthe more Pythonic approach of just raising an exception from anywhere in the code,\nand having it be handled automatically.\n\n## Client Interceptor\n\nWe will use an invocation metadata injecting interceptor as an example of defining\na client interceptor:\n\n```python\nfrom grpc_interceptor import ClientCallDetails, ClientInterceptor\n\nclass MetadataClientInterceptor(ClientInterceptor):\n\n    def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        call_details: grpc.ClientCallDetails,\n    ):\n        \"\"\"Override this method to implement a custom interceptor.\n\n        This method is called for all unary and streaming RPCs. The interceptor\n        implementation should call `method` using a `grpc.ClientCallDetails` and the\n        `request_or_iterator` object as parameters. The `request_or_iterator`\n        parameter may be type checked to determine if this is a singular request\n        for unary RPCs or an iterator for client-streaming or client-server streaming\n        RPCs.\n\n        Args:\n            method: A function that proceeds with the invocation by executing the next\n                interceptor in the chain or invoking the actual RPC on the underlying\n                channel.\n            request_or_iterator: RPC request message or iterator of request messages\n                for streaming requests.\n            call_details: Describes an RPC to be invoked.\n\n        Returns:\n            The type of the return should match the type of the return value received\n            by calling `method`. 
This is an object that is both a\n            `Call <https://grpc.github.io/grpc/python/grpc.html#grpc.Call>`_ for the\n            RPC and a `Future <https://grpc.github.io/grpc/python/grpc.html#grpc.Future>`_.\n\n            The actual result from the RPC can be got by calling `.result()` on the\n            value returned from `method`.\n        \"\"\"\n        new_details = ClientCallDetails(\n            call_details.method,\n            call_details.timeout,\n            [(\"authorization\", \"Bearer mysecrettoken\")],\n            call_details.credentials,\n            call_details.wait_for_ready,\n            call_details.compression,\n        )\n\n        return method(request_or_iterator, new_details)\n```\n\nNow inject your interceptor when you create the ``grpc`` channel:\n\n```python\ninterceptors = [MetadataClientInterceptor()]\nwith grpc.insecure_channel(\"grpc-server:50051\") as channel:\n    channel = grpc.intercept_channel(channel, *interceptors)\n    ...\n```\n\nClient interceptors can also be used to\n[retry RPCs](https://github.com/d5h-foss/grpc-interceptor/blob/4b6bb6a59aae97aec058c0d4072dd19de8f408bc/tests/test_client.py#L39-L56)\nthat fail due to specific errors, or a host of other use cases. There are some basic\napproaches in\n[the tests](https://github.com/d5h-foss/grpc-interceptor/blob/master/tests/test_client.py)\nto get you started.\n\nNote: The `method` in a client interceptor is a `continuation` as described in the\n[client interceptor section of the gRPC docs](https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryClientInterceptor.intercept_unary_unary).\nWhen you invoke the continuation, you get a future back, which resolves to either the\nresult, or exception. This is different than invoking a client stub, which returns the\nresult directly. 
If the interceptor needs the value returned by the call, or to catch\nexceptions, then you'll need to do `future = method(request_or_iterator, call_details)`,\nfollowed by `future.result()`. Check out the tests for\n[examples](https://github.com/d5h-foss/grpc-interceptor/blob/4b6bb6a59aae97aec058c0d4072dd19de8f408bc/tests/test_client.py#L39-L56).\n\n\n# Documentation\n\nThe examples above showed usage for simple unary-unary RPC calls. For examples of\nstreaming and asyncio RPCs, read the\n[complete documentation here](https://grpc-interceptor.readthedocs.io/).\n\nNote that there are no asyncio client interceptors at the moment, though contributions\nare welcome.\n"
  },
  {
    "path": "docs/conf.py",
    "content": "\"\"\"Sphinx configuration.\"\"\"\n\nimport re\n\n\nproject = \"grpc-interceptor\"\nauthor = \"Dan Hipschman\"\ncopyright = f\"2020, {author}\"\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n]\n\n\ndef setup(app):\n    \"\"\"Sphinx setup.\"\"\"\n    app.connect(\"autodoc-skip-member\", skip_member)\n\n\ndef skip_member(app, what, name, obj, skip, options):\n    \"\"\"Ignore ugly auto-generated doc strings from namedtuple.\"\"\"\n    doc = getattr(obj, \"__doc__\", \"\") or \"\"  # Handle when __doc__ is missing on None\n    is_namedtuple_docstring = bool(re.fullmatch(\"Alias for field number [0-9]+\", doc))\n    return is_namedtuple_docstring or skip\n"
  },
  {
    "path": "docs/index.rst",
    "content": "Simplified Python gRPC Interceptors\n===================================\n\n.. toctree::\n   :hidden:\n   :maxdepth: 1\n\n   reference\n   license\n\n.. contents::\n\nThe primary aim of this project is to make Python gRPC interceptors simple.\nThe Python ``grpc`` package provides service interceptors, but they're a bit hard to\nuse because of their flexibility. The ``grpc`` interceptors don't have direct access\nto the request and response objects, or the service context. Access to these are often\ndesired, to be able to log data in the request or response, or set status codes on the\ncontext.\n\nThe secondary aim of this project is to keep the code small and simple. Code you can\nread through and understand quickly gives you confidence and helps debug issues. When\nyou install this package, you also don't want a bunch of other packages that might\ncause conflicts within your project. Too many dependencies slow down installation\nas well as runtime (fresh imports take time). Hence, a goal of this project is to keep\ndependencies to a minimum. The only core dependency is the ``grpc`` package, and the\n``testing`` extra includes ``protobuf`` as well.\n\nThe ``grpc_interceptor`` package provides the following:\n\n* A ``ServerInterceptor`` base class, to make it easy to define your own server-side\n  interceptors. Do not confuse this with the ``grpc.ServerInterceptor`` class.\n* An ``AsyncServerInterceptor`` base class, which is the analogy for async server-side\n  interceptors.\n* An ``ExceptionToStatusInterceptor`` interceptor, so your service can raise exceptions\n  that set the gRPC status code correctly (rather than the default of every exception\n  resulting in an ``UNKNOWN`` status code). 
This is something for which pretty much any\n  service will have a use.\n* An ``AsyncExceptionToStatusInterceptor`` interceptor, which is the analogy for async\n  ``ExceptionToStatusInterceptor``.\n* A ``ClientInterceptor`` base class, to make it easy to define your own client-side interceptors.\n  Do not confuse this with the ``grpc.ClientInterceptor`` class. (Note, there is currently no\n  async analogy to ``ClientInterceptor``, though contributions are welcome.)\n* An optional testing framework. If you're writing your own interceptors, this is useful.\n  If you're just using ``ExceptionToStatusInterceptor`` then you don't need this.\n\nInstallation\n------------\n\nTo install just the interceptors:\n\n.. code-block:: console\n\n   $ pip install grpc-interceptor\n\nTo also install the testing framework:\n\n.. code-block:: console\n\n   $ pip install grpc-interceptor[testing]\n\nUsage\n-----\n\nServer Interceptors\n^^^^^^^^^^^^^^^^^^^\n\nTo define your own server interceptor (we can use a simplified version of\n``ExceptionToStatusInterceptor`` as an example):\n\n.. 
code-block:: python\n\n   from grpc_interceptor import ServerInterceptor\n   from grpc_interceptor.exceptions import GrpcException\n\n   class ExceptionToStatusInterceptor(ServerInterceptor):\n\n       def intercept(\n           self,\n           method: Callable,\n           request_or_iterator: Any,\n           context: grpc.ServicerContext,\n           method_name: str,\n       ) -> Any:\n           \"\"\"Override this method to implement a custom interceptor.\n\n            You should call method(request_or_iterator, context) to invoke the\n            next handler (either the RPC method implementation, or the\n            next interceptor in the list).\n\n            Args:\n                method: The next interceptor, or method implementation.\n                request_or_iterator: The RPC request, as a protobuf message.\n                context: The ServicerContext passed by gRPC to the service.\n                method_name: A string of the form\n                    \"/protobuf.package.Service/Method\"\n\n            Returns:\n                This should generally return the result of\n                method(request_or_iterator, context), which is typically the RPC\n                method response, as a protobuf message. The interceptor\n                is free to modify this in some way, however.\n            \"\"\"\n           try:\n               return method(request_or_iterator, context)\n           except GrpcException as e:\n               context.set_code(e.status_code)\n               context.set_details(e.details)\n               raise\n\nThen inject your interceptor when you create the ``grpc`` server:\n\n.. code-block:: python\n\n   interceptors = [ExceptionToStatusInterceptor()]\n   server = grpc.server(\n       futures.ThreadPoolExecutor(max_workers=10),\n       interceptors=interceptors\n   )\n\nTo use ``ExceptionToStatusInterceptor``:\n\n.. 
code-block:: python\n\n   from grpc_interceptor.exceptions import NotFound\n\n   class MyService(my_pb2_grpc.MyServiceServicer):\n       def MyRpcMethod(\n           self, request: MyRequest, context: grpc.ServicerContext\n       ) -> MyResponse:\n           thing = lookup_thing()\n           if not thing:\n               raise NotFound(\"Sorry, your thing is missing\")\n           ...\n\nThis results in the gRPC status code being set to ``NOT_FOUND``,\nand the details ``\"Sorry, your thing is missing\"``. This saves you the hassle of\ncatching exceptions in your service handler, or passing the context down into\nhelper functions so they can call ``context.abort`` or ``context.set_code``. It allows\nthe more Pythonic approach of just raising an exception from anywhere in the code,\nand having it be handled automatically.\n\nServer Streaming Interceptors\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nThe above example shows how to write an interceptor for a unary-unary RPC. Server\nstreaming RPCs need to be handled a little differently because ``method(request, context)``\nwill return a generator. Hence, the code won't actually run until you iterate over it.\nSo, if we were to continue the example of catching exceptions from RPCs, we would\nneed to do something like this:\n\n.. code-block:: python\n\n    class ExceptionToStatusInterceptor(ServerInterceptor):\n\n        def intercept(\n            self,\n            method: Callable,\n            request: Any,\n            context: grpc.ServicerContext,\n            method_name: str,\n        ) -> Any:\n            try:\n                for response in method(request, context):\n                    yield response\n            except GrpcException as e:\n                context.set_code(e.status_code)\n                context.set_details(e.details)\n                raise\n\nHowever, this will *only* work for server streaming RPCs. 
In order to work with both\nunary and streaming RPCs, you'll need to handle the unary case and streaming case\nseparately, like this:\n\n.. code-block:: python\n\n    class ExceptionToStatusInterceptor(ServerInterceptor):\n\n        def intercept(self, method, request, context, method_name):\n            # Call the RPC. It could be either unary or streaming\n            try:\n                response_or_iterator = method(request, context)\n            except GrpcException as e:\n                # If it was unary, then any exception raised would be caught\n                # immediately, so handle it here.\n                context.set_code(e.status_code)\n                context.set_details(e.details)\n                raise\n            # Check if it's streaming\n            if hasattr(response_or_iterator, \"__iter__\"):\n                # Now we know it's a server streaming RPC, so the actual RPC method\n                # hasn't run yet. Delegate to a helper to iterate over it so it runs.\n                # The helper needs to re-yield the responses, and we need to return\n                # the generator that it produces.\n                return self._intercept_streaming(response_or_iterator, context)\n            else:\n                # For unary cases, we are done, so just return the response.\n                return response_or_iterator\n\n        def _intercept_streaming(self, iterator, context):\n            try:\n                for resp in iterator:\n                    yield resp\n            except GrpcException as e:\n                context.set_code(e.status_code)\n                context.set_details(e.details)\n                raise\n\nAsync Server Interceptors\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nAsync interceptors are similar to sync ones, but there are two things of which you need to\nbe aware.\n\nFirst, async server streaming RPCs that are implemented with ``async def + yield``\ncannot be awaited. 
When you call such a method, you get back an ``async_generator``.\nThis is not ``await``-able (though you can ``async for`` loop over it). This is\nin contrast to a unary RPC, which is implemented with ``async def + return``. That results in a\ncoroutine when called, which you *can* ``await``.\n\nAll this is to say that you mustn't await ``method(request, context)`` in an async\ninterceptor immediately. First, check if it's an ``async_generator``. You can do this\nby checking for the presence of the ``__aiter__`` attribute.\n\nHere's an async version of our running ``ExceptionToStatusInterceptor`` example:\n\n.. code-block:: python\n\n    from grpc_interceptor.exceptions import GrpcException\n    from grpc_interceptor.server import AsyncServerInterceptor\n\n    class AsyncExceptionToStatusInterceptor(AsyncServerInterceptor):\n\n        async def intercept(\n            self,\n            method: Callable,\n            request_or_iterator: Any,\n            context: grpc.ServicerContext,\n            method_name: str,\n        ) -> Any:\n            try:\n                response_or_iterator = method(request_or_iterator, context)\n                if not hasattr(response_or_iterator, \"__aiter__\"):\n                    # Unary, just await and return the response\n                    return await response_or_iterator\n            except GrpcException as e:\n                await context.set_code(e.status_code)\n                await context.set_details(e.details)\n                raise\n\n            # Server streaming responses, delegate to an async generator helper.\n            # Note that we do NOT await this.\n            return self._intercept_streaming(response_or_iterator, context)\n\n        async def _intercept_streaming(self, iterator, context):\n            try:\n                async for r in iterator:\n                    yield r\n            except GrpcException as e:\n                await context.set_code(e.status_code)\n                await 
context.set_details(e.details)\n                raise\n\nThe second thing you must be aware of with async RPCs is that an\n`alternate streaming API <https://github.com/lidizheng/proposal/blob/grpc-python-async-api/L58-python-async-api.md#new-streaming-api---reader--writer-api>`_\nwas added. With this API, instead of writing a server streaming RPC with\n``async def + yield``, you write it as ``async def + return``, but it returns ``None``.\nThe way it streams responses is by calling ``await context.write(...)`` for each response\nit streams. Similarly, client streaming can be achieved by calling\n``await context.read()`` instead of iterating over the request object.\n\nIf you must support RPC services written using this new API, then you must be aware that\na server streaming RPC could return ``None``. In that case it will not be an\n``async_generator`` even though it's streaming. You will also need your own solution to\nget access to the streaming response objects. For example, you could wrap the\n``context`` object that you pass to ``method(request, context)``, so that you can\ncapture ``read`` and ``write`` calls.\n\nClient Interceptors\n^^^^^^^^^^^^^^^^^^^\n\nWe will use an invocation metadata injecting interceptor as an example of defining\na client interceptor:\n\n.. code-block:: python\n\n    from grpc_interceptor import ClientCallDetails, ClientInterceptor\n\n    class MetadataClientInterceptor(ClientInterceptor):\n\n        def intercept(\n            self,\n            method: Callable,\n            request_or_iterator: Any,\n            call_details: grpc.ClientCallDetails,\n        ):\n            \"\"\"Override this method to implement a custom interceptor.\n\n            This method is called for all unary and streaming RPCs. The interceptor\n            implementation should call `method` using a `grpc.ClientCallDetails` and the\n            `request_or_iterator` object as parameters. 
The `request_or_iterator`\n            parameter may be type checked to determine if this is a singular request\n            for unary RPCs or an iterator for client-streaming or client-server streaming\n            RPCs.\n\n            Args:\n                method: A function that proceeds with the invocation by executing the next\n                    interceptor in the chain or invoking the actual RPC on the underlying\n                    channel.\n                request_or_iterator: RPC request message or iterator of request messages\n                    for streaming requests.\n                call_details: Describes an RPC to be invoked.\n\n            Returns:\n                The type of the return should match the type of the return value received\n                by calling `method`. This is an object that is both a\n                `Call <https://grpc.github.io/grpc/python/grpc.html#grpc.Call>`_ for the\n                RPC and a `Future <https://grpc.github.io/grpc/python/grpc.html#grpc.Future>`_.\n\n                The actual result from the RPC can be got by calling `.result()` on the\n                value returned from `method`.\n            \"\"\"\n            new_details = ClientCallDetails(\n                call_details.method,\n                call_details.timeout,\n                [(\"authorization\", \"Bearer mysecrettoken\")],\n                call_details.credentials,\n                call_details.wait_for_ready,\n                call_details.compression,\n            )\n\n            return method(request_or_iterator, new_details)\n\nNow inject your interceptor when you create the ``grpc`` channel:\n\n.. 
code-block:: python\n\n    interceptors = [MetadataClientInterceptor()]\n    with grpc.insecure_channel(\"grpc-server:50051\") as channel:\n        channel = grpc.intercept_channel(channel, *interceptors)\n        ...\n\nClient interceptors can also be used to\n`retry RPCs <https://github.com/d5h-foss/grpc-interceptor/blob/4b6bb6a59aae97aec058c0d4072dd19de8f408bc/tests/test_client.py#L39-L56>`_\nthat fail due to specific errors, or a host of other use cases. There are some basic\napproaches in\n`the tests <https://github.com/d5h-foss/grpc-interceptor/blob/master/tests/test_client.py>`_\nto get you started.\n\nNote: The ``method`` in a client interceptor is a ``continuation`` as described in the\n`client interceptor section of the gRPC docs <https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryClientInterceptor.intercept_unary_unary>`_.\nWhen you invoke the continuation, you get a future back, which resolves to either the\nresult, or exception. This is different than invoking a client stub, which returns the\nresult directly. If the interceptor needs the value returned by the call, or to catch\nexceptions, then you'll need to do ``future = method(request_or_iterator, call_details)``,\nfollowed by ``future.result()``. Check out the tests for\n`examples <https://github.com/d5h-foss/grpc-interceptor/blob/4b6bb6a59aae97aec058c0d4072dd19de8f408bc/tests/test_client.py#L39-L56>`_.\n\nTesting\n-------\n\nThe testing framework provides an actual gRPC service and client, which you can inject\ninterceptors into. This allows end-to-end testing, rather than mocking things out (such\nas the context). This can catch interactions between your interceptors and the gRPC\nframework, and also allows chaining interceptors.\n\nThe crux of the testing framework is the ``dummy_client`` context manager. 
It provides\na client to a gRPC service, which by default echoes the ``input`` field of the request\nto the ``output`` field of the response.\n\nYou can also provide a ``special_cases`` dict which tells the service to call arbitrary\nfunctions when the input matches a key in the dict. This allows you to test things like\nexceptions being thrown.\n\nHere's an example (again using ``ExceptionToStatusInterceptor``):\n\n.. code-block:: python\n\n   from grpc_interceptor import ExceptionToStatusInterceptor\n   from grpc_interceptor.exceptions import NotFound\n   from grpc_interceptor.testing import dummy_client, DummyRequest, raises\n\n   def test_exception():\n       special_cases = {\"error\": raises(NotFound())}\n       interceptors = [ExceptionToStatusInterceptor()]\n       with dummy_client(special_cases=special_cases, interceptors=interceptors) as client:\n           # Test a happy path first\n           assert client.Execute(DummyRequest(input=\"foo\")).output == \"foo\"\n           # And now a special case\n           with pytest.raises(grpc.RpcError) as e:\n               client.Execute(DummyRequest(input=\"error\"))\n           assert e.value.code() == grpc.StatusCode.NOT_FOUND\n\nLimitations\n-----------\n\nKnown limitations:\n\n* Async client interceptors are not implemented.\n* The ``read`` / ``write`` API for async streaming technically works,\n  but you'll need to roll your own solution to get access to streaming\n  request and response objects.\n\nContributions or requests are welcome for any limitations you may find.\n"
  },
  {
    "path": "docs/license.rst",
    "content": "License\n=======\n\n.. include:: ../LICENSE\n"
  },
  {
    "path": "docs/reference.rst",
    "content": "Reference\n=========\n\n.. contents::\n    :local:\n    :backlinks: none\n\ngrpc_interceptor\n---------------------\n\n.. automodule:: grpc_interceptor\n   :members:\n\ngrpc_interceptor.exceptions\n------------------------------------\n\n.. automodule:: grpc_interceptor.exceptions\n   :members:\n\ngrpc_interceptor.testing\n------------------------------------\n\n.. automodule:: grpc_interceptor.testing\n   :members:\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "sphinx==3.1.2\ndocutils==0.16\n"
  },
  {
    "path": "mypy.ini",
    "content": "[mypy]\n\n[mypy-nox.*,grpc,pytest]\nignore_missing_imports = True\n"
  },
  {
    "path": "noxfile.py",
    "content": "\"\"\"Nox sessions.\"\"\"\n\nfrom contextlib import contextmanager\nfrom pathlib import Path\nimport tempfile\nfrom typing import List\nfrom uuid import uuid4\n\nimport nox\nimport toml\n\n\nnox.options.sessions = \"lint\", \"mypy\", \"tests\", \"xdoctest\", \"mindeps\"\nPY_VERSIONS = [\"3.11\", \"3.10\", \"3.9\", \"3.8\", \"3.7\"]\nPY_LATEST = \"3.11\"\n\n\n@nox.session(python=PY_VERSIONS)\ndef tests(session):\n    \"\"\"Run the test suite.\"\"\"\n    args = session.posargs or [\n        \"--cov\",\n        \"--cov-report\",\n        \"term-missing\",\n        \"--cov-report\",\n        \"xml\",\n        \"-ra\",\n        \"-vv\",\n    ]\n    session.run(\"poetry\", \"install\", \"--no-dev\", external=True)\n    install_with_constraints(\n        session,\n        \"coverage[toml]\",\n        \"grpcio-tools\",\n        \"pytest\",\n        \"pytest-asyncio\",\n        \"pytest-cov\",\n    )\n    session.run(\"pytest\", *args)\n\n\n@nox.session(python=PY_VERSIONS)\ndef xdoctest(session) -> None:\n    \"\"\"Run examples with xdoctest.\"\"\"\n    args = session.posargs or [\"all\"]\n    session.run(\"poetry\", \"install\", \"--no-dev\", external=True)\n    install_with_constraints(session, \"xdoctest\")\n    session.run(\"python\", \"-m\", \"xdoctest\", \"grpc_interceptor\", *args)\n\n\n@nox.session(python=PY_LATEST)\ndef docs(session):\n    \"\"\"Build the documentation.\"\"\"\n    session.run(\"poetry\", \"install\", \"--no-dev\", \"-E\", \"testing\", external=True)\n    install_with_constraints(session, \"sphinx\")\n    session.run(\"sphinx-build\", \"docs\", \"docs/_build\")\n\n\nSOURCE_CODE = [\"src\", \"tests\", \"noxfile.py\", \"docs/conf.py\"]\n\n\n@nox.session(python=PY_LATEST)\ndef black(session):\n    \"\"\"Run black code formatter.\"\"\"\n    args = session.posargs or SOURCE_CODE\n    install_with_constraints(session, \"black\")\n    session.run(\"black\", *args)\n\n\n@nox.session(python=PY_LATEST)\ndef lint(session):\n    \"\"\"Lint using 
flake8.\"\"\"\n    args = session.posargs or SOURCE_CODE\n    install_with_constraints(\n        session,\n        \"flake8\",\n        \"flake8-bandit\",\n        \"flake8-bugbear\",\n        \"flake8-docstrings\",\n        \"flake8-import-order\",\n    )\n    session.run(\"flake8\", *args)\n\n\n@nox.session(python=PY_VERSIONS)\ndef mypy(session):\n    \"\"\"Type-check using mypy.\"\"\"\n    args = session.posargs or SOURCE_CODE\n    install_with_constraints(session, \"mypy\")\n    session.run(\"mypy\", \"--install-types\", \"--non-interactive\", *args)\n    session.run(\"mypy\", *args)\n\n\n@nox.session(python=PY_LATEST)\ndef safety(session):\n    \"\"\"Scan dependencies for insecure packages.\"\"\"\n    with _temp_file() as requirements:\n        session.run(\n            \"poetry\",\n            \"export\",\n            \"--dev\",\n            \"--format=requirements.txt\",\n            \"--without-hashes\",\n            f\"--output={requirements}\",\n            external=True,\n        )\n        install_with_constraints(session, \"safety\")\n        session.run(\"safety\", \"check\", f\"--file={requirements}\", \"--full-report\")\n\n\n@nox.session(python=\"3.7\")\ndef mindeps(session):\n    \"\"\"Run test with minimum versions of dependencies.\"\"\"\n    deps = _parse_minimum_dependency_versions()\n    session.install(*deps)\n    session.run(\"pytest\", env={\"PYTHONPATH\": \"src\"})\n\n\ndef install_with_constraints(session, *args, **kwargs):\n    \"\"\"Install packages constrained by Poetry's lock file.\"\"\"\n    with _temp_file() as requirements:\n        session.run(\n            \"poetry\",\n            \"export\",\n            \"--dev\",\n            \"--format=requirements.txt\",\n            f\"--output={requirements}\",\n            \"--without-hashes\",\n            external=True,\n        )\n        session.install(f\"--constraint={requirements}\", *args, **kwargs)\n\n\n@contextmanager\ndef _temp_file():\n    # NamedTemporaryFile doesn't work on 
Windows.\n    path = Path(tempfile.gettempdir()) / str(uuid4())\n    try:\n        yield path\n    finally:\n        try:\n            path.unlink()\n        except FileNotFoundError:\n            pass\n\n\ndef _parse_minimum_dependency_versions() -> List[str]:\n    pyproj = toml.load(\"pyproject.toml\")\n    dependencies = pyproj[\"tool\"][\"poetry\"][\"dependencies\"]\n    dev_dependencies = pyproj[\"tool\"][\"poetry\"][\"dev-dependencies\"]\n    min_deps = []\n\n    for deps in (dependencies, dev_dependencies):\n        for dep, constraint in deps.items():\n            if dep == \"python\":\n                continue\n\n            if not isinstance(constraint, str):\n                # Don't install deps with Python constraints, because they're always for\n                # newer versions of Python.\n                if \"python\" in constraint:\n                    continue\n                constraint = constraint[\"version\"]\n\n            if constraint.startswith(\"^\") or constraint.startswith(\"~\"):\n                version = constraint[1:]\n            elif constraint.startswith(\">=\"):\n                version = constraint[2:]\n            else:\n                version = constraint\n\n            min_deps.append(f\"{dep}=={version}\")\n\n    return min_deps\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.poetry]\nname = \"grpc-interceptor\"\nversion = \"0.15.4\"\ndescription = \"Simplifies gRPC interceptors\"\nlicense = \"MIT\"\nreadme = \"README.md\"\nhomepage = \"https://github.com/d5h-foss/grpc-interceptor\"\nrepository = \"https://github.com/d5h-foss/grpc-interceptor\"\nkeywords = [\"grpc\", \"interceptor\"]\nauthors = [\"Dan Hipschman\"]\ndocumentation = \"https://grpc-interceptor.readthedocs.io\"\n\n[tool.poetry.dependencies]\npython = \"^3.7\"\n# Earliest version that supports aio w/o deadlock issue: https://github.com/grpc/grpc/pull/23945\ngrpcio = \"^1.49.1\"\n# https://github.com/protocolbuffers/protobuf/issues/10075\nprotobuf = {version = \">=4.21.9\", optional = true}\n\n[tool.poetry.extras]\ntesting = [\"protobuf\"]\n\n[tool.poetry.dev-dependencies]\npytest = \"^6.1.0\"\ngrpcio-tools = \"^1.49.1\"\ncoverage = {extras = [\"toml\"], version = \"^7.2.3\"}\npytest-cov = \"^2.10.0\"\nblack = \"^23.3.0\"\nflake8 = \"^5.0.0\"\nflake8-bandit = \"^4.1.1\"  # https://github.com/tylerwince/flake8-bandit/issues/21\nflake8-bugbear = \"^20.1.4\"\nflake8-import-order = \"^0.18.1\"\nsafety = \"^1.9.0\"\nmypy = \"^1.2.0\"\nmypy-protobuf = \"^1.23\"\nflake8-docstrings = \"^1.5.0\"\nsphinx = \"^3.1.2\"\nxdoctest = \"^0.13.0\"\npytest-asyncio = {version = \"^0.19.0\", python = \">=3.7\"}\n\n[tool.coverage.paths]\nsource = [\"src\"]\n\n[tool.coverage.run]\nbranch = true\nsource = [\"grpc_interceptor\"]\nomit = [\"*_pb2.py\", \"*_pb2_grpc.py\"]\n\n[tool.coverage.report]\nshow_missing = true\n\n[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n"
  },
  {
    "path": "src/grpc_interceptor/__init__.py",
    "content": "\"\"\"Simplified Python gRPC interceptors.\"\"\"\n\nfrom grpc_interceptor.client import ClientCallDetails, ClientInterceptor\nfrom grpc_interceptor.exception_to_status import (\n    AsyncExceptionToStatusInterceptor,\n    ExceptionToStatusInterceptor,\n)\nfrom grpc_interceptor.server import (\n    AsyncServerInterceptor,\n    MethodName,\n    parse_method_name,\n    ServerInterceptor,\n)\n\n\n__all__ = [\n    \"AsyncExceptionToStatusInterceptor\",\n    \"AsyncServerInterceptor\",\n    \"ClientCallDetails\",\n    \"ClientInterceptor\",\n    \"ExceptionToStatusInterceptor\",\n    \"MethodName\",\n    \"parse_method_name\",\n    \"ServerInterceptor\",\n]\n"
  },
  {
    "path": "src/grpc_interceptor/client.py",
    "content": "\"\"\"Base class for client-side interceptors.\"\"\"\n\nimport abc\nfrom typing import Any, Callable, Iterator, NamedTuple, Optional, Sequence, Tuple, Union\n\nimport grpc\n\n\nclass _ClientCallDetailsFields(NamedTuple):\n    method: str\n    timeout: Optional[float]\n    metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]]\n    credentials: Optional[grpc.CallCredentials]\n    wait_for_ready: Optional[bool]\n    compression: Any  # Type added in grpcio 1.23.0\n\n\nclass ClientCallDetails(_ClientCallDetailsFields, grpc.ClientCallDetails):\n    \"\"\"Describes an RPC to be invoked.\n\n    See https://grpc.github.io/grpc/python/grpc.html#grpc.ClientCallDetails\n    \"\"\"\n\n    pass\n\n\nclass ClientInterceptorReturnType(grpc.Call, grpc.Future):\n    \"\"\"Return type for the ClientInterceptor.intercept method.\"\"\"\n\n    pass\n\n\nclass ClientInterceptor(\n    grpc.UnaryUnaryClientInterceptor,\n    grpc.UnaryStreamClientInterceptor,\n    grpc.StreamUnaryClientInterceptor,\n    grpc.StreamStreamClientInterceptor,\n    metaclass=abc.ABCMeta,\n):\n    \"\"\"Base class for client-side interceptors.\n\n    To implement an interceptor, subclass this class and override the intercept method.\n    \"\"\"\n\n    @abc.abstractmethod\n    def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        call_details: grpc.ClientCallDetails,\n    ) -> ClientInterceptorReturnType:\n        \"\"\"Override this method to implement a custom interceptor.\n\n        This method is called for all unary and streaming RPCs. The interceptor\n        implementation should call `method` using a `grpc.ClientCallDetails` and the\n        `request_or_iterator` object as parameters. 
The `request_or_iterator`\n        parameter may be type checked to determine if this is a singular request\n        for unary RPCs or an iterator for client-streaming or client-server streaming\n        RPCs.\n\n        Args:\n            method: A function that proceeds with the invocation by executing the next\n                interceptor in the chain or invoking the actual RPC on the underlying\n                channel.\n            request_or_iterator: RPC request message or iterator of request messages\n                for streaming requests.\n            call_details: Describes an RPC to be invoked.\n\n        Returns:\n            The type of the return should match the type of the return value received\n            by calling `method`. This is an object that is both a\n            `Call <https://grpc.github.io/grpc/python/grpc.html#grpc.Call>`_ for the\n            RPC and a\n            `Future <https://grpc.github.io/grpc/python/grpc.html#grpc.Future>`_.\n\n            The actual result from the RPC can be got by calling `.result()` on the\n            value returned from `method`.\n        \"\"\"\n        return method(request_or_iterator, call_details)  # pragma: no cover\n\n    def intercept_unary_unary(\n        self,\n        continuation: Callable,\n        call_details: grpc.ClientCallDetails,\n        request: Any,\n    ):\n        \"\"\"Implementation of grpc.UnaryUnaryClientInterceptor.\n\n        This is not part of the grpc_interceptor.ClientInterceptor API, but must have\n        a public name. 
Do not override it, unless you know what you're doing.\n        \"\"\"\n        return self.intercept(_swap_args(continuation), request, call_details)\n\n    def intercept_unary_stream(\n        self,\n        continuation: Callable,\n        call_details: grpc.ClientCallDetails,\n        request: Any,\n    ):\n        \"\"\"Implementation of grpc.UnaryStreamClientInterceptor.\n\n        This is not part of the grpc_interceptor.ClientInterceptor API, but must have\n        a public name. Do not override it, unless you know what you're doing.\n        \"\"\"\n        return self.intercept(_swap_args(continuation), request, call_details)\n\n    def intercept_stream_unary(\n        self,\n        continuation: Callable,\n        call_details: grpc.ClientCallDetails,\n        request_iterator: Iterator[Any],\n    ):\n        \"\"\"Implementation of grpc.StreamUnaryClientInterceptor.\n\n        This is not part of the grpc_interceptor.ClientInterceptor API, but must have\n        a public name. Do not override it, unless you know what you're doing.\n        \"\"\"\n        return self.intercept(_swap_args(continuation), request_iterator, call_details)\n\n    def intercept_stream_stream(\n        self,\n        continuation: Callable,\n        call_details: grpc.ClientCallDetails,\n        request_iterator: Iterator[Any],\n    ):\n        \"\"\"Implementation of grpc.StreamStreamClientInterceptor.\n\n        This is not part of the grpc_interceptor.ClientInterceptor API, but must have\n        a public name. Do not override it, unless you know what you're doing.\n        \"\"\"\n        return self.intercept(_swap_args(continuation), request_iterator, call_details)\n\n\ndef _swap_args(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:\n    def new_fn(x, y):\n        return fn(y, x)\n\n    return new_fn\n"
  },
  {
    "path": "src/grpc_interceptor/exception_to_status.py",
    "content": "\"\"\"ExceptionToStatusInterceptor catches GrpcException and sets the gRPC context.\"\"\"\n\n# TODO: use asynccontextmanager\nfrom contextlib import contextmanager\nfrom typing import (\n    Any,\n    AsyncGenerator,\n    AsyncIterable,\n    Callable,\n    Generator,\n    Iterable,\n    Iterator,\n    NoReturn,\n    Optional,\n)\n\nimport grpc\nfrom grpc import aio as grpc_aio\n\nfrom grpc_interceptor.exceptions import GrpcException\nfrom grpc_interceptor.server import AsyncServerInterceptor, ServerInterceptor\n\n\nclass ExceptionToStatusInterceptor(ServerInterceptor):\n    \"\"\"An interceptor that catches exceptions and sets the RPC status and details.\n\n    ExceptionToStatusInterceptor will catch any subclass of GrpcException and set the\n    status code and details on the gRPC context. You can also extend this and override\n    the handle_exception method to catch other types of exceptions, and handle them in\n    different ways. E.g., you can catch and handle exceptions that don't derive from\n    GrpcException. Or you can set rich error statuses with context.abort_with_status().\n\n    Args:\n        status_on_unknown_exception: Specify what to do if an exception which is\n            not a subclass of GrpcException is raised. If None, do nothing (by\n            default, grpc will set the status to UNKNOWN). If not None, then the\n            status code will be set to this value if `context.abort` hasn't been called\n            earlier. It must not be OK. The details will be set to the value of repr(e),\n            where e is the exception. 
In any case, the exception will be propagated.\n\n    Raises:\n        ValueError: If status_code is OK.\n    \"\"\"\n\n    def __init__(self, status_on_unknown_exception: Optional[grpc.StatusCode] = None):\n        if status_on_unknown_exception == grpc.StatusCode.OK:\n            raise ValueError(\"The status code for unknown exceptions cannot be OK\")\n\n        self._status_on_unknown_exception = status_on_unknown_exception\n\n    def _generate_responses(\n        self,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n        response_iterator: Iterable,\n    ) -> Generator[Any, None, None]:\n        \"\"\"Yield all the responses, but check for errors along the way.\"\"\"\n        with self._handle_exception(request_or_iterator, context, method_name):\n            yield from response_iterator\n\n    @contextmanager\n    def _handle_exception(\n        self, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str\n    ) -> Iterator[None]:\n        try:\n            yield\n        except Exception as ex:\n            self.handle_exception(ex, request_or_iterator, context, method_name)\n\n    def handle_exception(\n        self,\n        ex: Exception,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n    ) -> NoReturn:\n        \"\"\"Override this if extending ExceptionToStatusInterceptor.\n\n        This will get called when an exception is raised while handling the RPC.\n\n        Args:\n            ex: The exception that was raised.\n            request_or_iterator: The RPC request, as a protobuf message if it is a\n                unary request, or an iterator of protobuf messages if it is a streaming\n                request.\n            context: The servicer context. 
You probably want to call context.abort(...)\n            method_name: The name of the RPC being called.\n\n        Raises:\n            This method must raise and cannot return, as in general there's no\n            meaningful RPC response to return if an exception has occurred. You can\n            raise the original exception, ex, or something else.\n        \"\"\"\n        if isinstance(ex, GrpcException):\n            context.abort(ex.status_code, ex.details)\n        elif not context.code():\n            if self._status_on_unknown_exception is not None:\n                context.abort(self._status_on_unknown_exception, repr(ex))\n        raise ex\n\n    def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n    ) -> Any:\n        \"\"\"Do not call this directly; use the interceptor kwarg on grpc.server().\"\"\"\n        with self._handle_exception(request_or_iterator, context, method_name):\n            response_or_iterator = method(request_or_iterator, context)\n\n        if isinstance(response_or_iterator, Iterable):\n            # multiple responses; return a generator\n            return self._generate_responses(\n                request_or_iterator, context, method_name, response_or_iterator\n            )\n        else:\n            # return a single response\n            return response_or_iterator\n\n\nclass AsyncExceptionToStatusInterceptor(AsyncServerInterceptor):\n    \"\"\"An interceptor that catches exceptions and sets the RPC status and details.\n\n    This is the async analogy to ExceptionToStatusInterceptor. 
Please see that class'\n    documentation for more information.\n    \"\"\"\n\n    def __init__(self, status_on_unknown_exception: Optional[grpc.StatusCode] = None):\n        if status_on_unknown_exception == grpc.StatusCode.OK:\n            raise ValueError(\"The status code for unknown exceptions cannot be OK\")\n\n        self._status_on_unknown_exception = status_on_unknown_exception\n\n    async def _generate_responses(\n        self,\n        request_or_iterator: Any,\n        context: grpc_aio.ServicerContext,\n        method_name: str,\n        response_iterator: AsyncIterable,\n    ) -> AsyncGenerator[Any, None]:\n        \"\"\"Yield all the responses, but check for errors along the way.\"\"\"\n        try:\n            async for r in response_iterator:\n                yield r\n        except Exception as ex:\n            await self.handle_exception(ex, request_or_iterator, context, method_name)\n\n    async def handle_exception(\n        self,\n        ex: Exception,\n        request_or_iterator: Any,\n        context: grpc_aio.ServicerContext,\n        method_name: str,\n    ) -> NoReturn:\n        \"\"\"Override this if extending ExceptionToStatusInterceptor.\n\n        This will get called when an exception is raised while handling the RPC.\n\n        Args:\n            ex: The exception that was raised.\n            request_or_iterator: The RPC request, as a protobuf message if it is a\n                unary request, or an iterator of protobuf messages if it is a streaming\n                request.\n            context: The servicer context. You probably want to call context.abort(...)\n            method_name: The name of the RPC being called.\n\n        Raises:\n            This method must raise and cannot return, as in general there's no\n            meaningful RPC response to return if an exception has occurred. 
You can\n            raise the original exception, ex, or something else.\n        \"\"\"\n        if isinstance(ex, GrpcException):\n            await context.abort(ex.status_code, ex.details)\n        elif not context.code():\n            if self._status_on_unknown_exception is not None:\n                await context.abort(self._status_on_unknown_exception, repr(ex))\n        raise ex\n\n    async def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc_aio.ServicerContext,\n        method_name: str,\n    ) -> Any:\n        \"\"\"Do not call this directly; use the interceptor kwarg on grpc.server().\"\"\"\n        try:\n            response_or_iterator = method(request_or_iterator, context)\n            if not hasattr(response_or_iterator, \"__aiter__\"):\n                return await response_or_iterator\n        except Exception as ex:\n            await self.handle_exception(ex, request_or_iterator, context, method_name)\n\n        return self._generate_responses(\n            request_or_iterator, context, method_name, response_or_iterator\n        )\n"
  },
  {
    "path": "src/grpc_interceptor/exceptions.py",
    "content": "\"\"\"Exceptions for ExceptionToStatusInterceptor.\n\nSee https://grpc.github.io/grpc/core/md_doc_statuscodes.html for the source of truth\non status code meanings.\n\"\"\"\n\nfrom typing import Optional\n\nfrom grpc import StatusCode\n\n\nclass GrpcException(Exception):\n    \"\"\"Base class for gRPC exceptions.\n\n    Generally you would not use this class directly, but rather use a subclass\n    representing one of the standard gRPC status codes (see:\n    https://grpc.github.io/grpc/core/md_doc_statuscodes.html for the official list).\n\n    Attributes:\n        status_code: A grpc.StatusCode other than OK. The only use case for this\n            is if gRPC adds a new status code that isn't represented by one of the\n            subclasses of GrpcException. Must not be OK, because gRPC will not\n            raise an RpcError to the client if the status code is OK.\n        details: A string with additional informantion about the error.\n    Args:\n        details: If not None, specifies a custom error message.\n        status_code: If not None, sets the status code.\n\n    Raises:\n        ValueError: If status_code is OK.\n    \"\"\"\n\n    status_code: StatusCode = StatusCode.UNKNOWN\n    details: str = \"Unknown exception occurred\"\n\n    def __init__(\n        self, details: Optional[str] = None, status_code: Optional[StatusCode] = None\n    ):\n        if status_code is not None:\n            if status_code == StatusCode.OK:\n                raise ValueError(\"The status code for an exception cannot be OK\")\n            self.status_code = status_code\n        if details is not None:\n            self.details = details\n\n    def __repr__(self) -> str:\n        \"\"\"Show the status code and details.\n\n        Returns:\n            A string displaying the class name, status code, and details.\n        \"\"\"\n        clsname = self.__class__.__name__\n        sc = self.status_code.name\n        return f\"{clsname}(status_code={sc}, 
details={self.details!r})\"\n\n    @property\n    def status_string(self):\n        \"\"\"Return status_code as a string.\n\n        Returns:\n            The status code as a string.\n\n        Example:\n            >>> GrpcException(status_code=StatusCode.NOT_FOUND).status_string\n            'NOT_FOUND'\n        \"\"\"\n        return self.status_code.name\n\n\nclass Aborted(GrpcException):\n    \"\"\"The operation was aborted.\n\n    Typically this is due to a concurrency issue such as a sequencer check failure or\n    transaction abort. See the guidelines on other exceptions for deciding between\n    FAILED_PRECONDITION, ABORTED, and UNAVAILABLE.\n    \"\"\"\n\n    status_code = StatusCode.ABORTED\n    details = \"The operation was aborted\"\n\n\nclass AlreadyExists(GrpcException):\n    \"\"\"The entity that a client attempted to create already exists.\n\n    E.g., a file or directory that a client is trying to create already exists.\n    \"\"\"\n\n    status_code = StatusCode.ALREADY_EXISTS\n    details = \"The entity attempted to be created already exists\"\n\n\nclass Cancelled(GrpcException):\n    \"\"\"The operation was cancelled, typically by the caller.\"\"\"\n\n    status_code = StatusCode.CANCELLED\n    details = \"The operation was cancelled\"\n\n\nclass DataLoss(GrpcException):\n    \"\"\"Unrecoverable data loss or corruption.\"\"\"\n\n    status_code = StatusCode.DATA_LOSS\n    details = \"There was unrecoverable data loss or corruption\"\n\n\nclass DeadlineExceeded(GrpcException):\n    \"\"\"The deadline expired before the operation could complete.\n\n    For operations that change the state of the system, this error may be returned even\n    if the operation has completed successfully. 
For example, a successful response\n    from a server could have been delayed long.\n    \"\"\"\n\n    status_code = StatusCode.DEADLINE_EXCEEDED\n    details = \"Deadline expired before operation could complete\"\n\n\nclass FailedPrecondition(GrpcException):\n    \"\"\"The operation failed because the system is in an invalid state for execution.\n\n    For example, the directory to be deleted is non-empty, an rmdir operation is\n    applied to a non-directory, etc. Service implementors can use the following\n    guidelines to decide between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE:\n    (a) Use UNAVAILABLE if the client can retry just the failing call. (b) Use ABORTED\n    if the client should retry at a higher level (e.g., when a client-specified\n    test-and-set fails, indicating the client should restart a read-modify-write\n    sequence). (c) Use FAILED_PRECONDITION if the client should not retry until the\n    system state has been explicitly fixed. E.g., if an \"rmdir\" fails because the\n    directory is non-empty, FAILED_PRECONDITION should be returned since the client\n    should not retry unless the files are deleted from the directory.\n    \"\"\"\n\n    status_code = StatusCode.FAILED_PRECONDITION\n    details = (\n        \"The operation was rejected because the system is not\"\n        \" in a state required for execution\"\n    )\n\n\nclass InvalidArgument(GrpcException):\n    \"\"\"The client specified an invalid argument.\n\n    Note that this differs from FAILED_PRECONDITION. 
INVALID_ARGUMENT indicates\n    arguments that are problematic regardless of the state of the system (e.g., a\n    malformed file name).\n    \"\"\"\n\n    status_code = StatusCode.INVALID_ARGUMENT\n    details = \"The client specified an invalid argument\"\n\n\nclass Internal(GrpcException):\n    \"\"\"Internal errors.\n\n    This means that some invariants expected by the underlying system have been broken.\n    This error code is reserved for serious errors.\n    \"\"\"\n\n    status_code = StatusCode.INTERNAL\n    details = \"Internal error\"\n\n\nclass OutOfRange(GrpcException):\n    \"\"\"The operation was attempted past the valid range.\n\n    E.g., seeking or reading past end-of-file. Unlike INVALID_ARGUMENT, this error\n    indicates a problem that may be fixed if the system state changes. For example, a\n    32-bit file system will generate INVALID_ARGUMENT if asked to read at an offset\n    that is not in the range [0,2^32-1], but it will generate OUT_OF_RANGE if asked to\n    read from an offset past the current file size. There is a fair bit of overlap\n    between FAILED_PRECONDITION and OUT_OF_RANGE. 
We recommend using OUT_OF_RANGE (the\n    more specific error) when it applies so that callers who are iterating through a\n    space can easily look for an OUT_OF_RANGE error to detect when they are done.\n    \"\"\"\n\n    status_code = StatusCode.OUT_OF_RANGE\n    details = \"The operation was attempted past the valid range\"\n\n\nclass NotFound(GrpcException):\n    \"\"\"Some requested entity (e.g., file or directory) was not found.\n\n    Note to server developers: if a request is denied for an entire class of users,\n    such as gradual feature rollout or undocumented whitelist, NOT_FOUND may be used.\n    If a request is denied for some users within a class of users, such as user-based\n    access control, PERMISSION_DENIED must be used.\n    \"\"\"\n\n    status_code = StatusCode.NOT_FOUND\n    details = \"The requested entity was not found\"\n\n\nclass PermissionDenied(GrpcException):\n    \"\"\"The caller does not have permission to execute the specified operation.\n\n    PERMISSION_DENIED must not be used for rejections caused by exhausting some\n    resource (use RESOURCE_EXHAUSTED instead for those errors). PERMISSION_DENIED\n    must not be used if the caller can not be identified (use UNAUTHENTICATED instead\n    for those errors). 
This error code does not imply the request is valid or the\n    requested entity exists or satisfies other pre-conditions.\n    \"\"\"\n\n    status_code = StatusCode.PERMISSION_DENIED\n    details = \"The caller does not have permission to execute the specified operation\"\n\n\nclass ResourceExhausted(GrpcException):\n    \"\"\"Some resource has been exhausted.\n\n    Perhaps a per-user quota, or perhaps the entire file system is out of space.\n    \"\"\"\n\n    status_code = StatusCode.RESOURCE_EXHAUSTED\n    details = \"A resource has been exhausted\"\n\n\nclass Unauthenticated(GrpcException):\n    \"\"\"The request does not have valid authentication credentials for the operation.\"\"\"\n\n    status_code = StatusCode.UNAUTHENTICATED\n    details = (\n        \"The request does not have valid authentication credentials for the operation\"\n    )\n\n\nclass Unavailable(GrpcException):\n    \"\"\"The service is currently unavailable.\n\n    This is most likely a transient condition, which can be corrected by retrying with\n    a backoff. Note that it is not always safe to retry non-idempotent operations.\n    \"\"\"\n\n    status_code = StatusCode.UNAVAILABLE\n    details = \"The service is currently unavailable\"\n\n\nclass Unimplemented(GrpcException):\n    \"\"\"The operation is not implemented or is not supported/enabled in this service.\"\"\"\n\n    status_code = StatusCode.UNIMPLEMENTED\n    details = (\n        \"The operation is not implemented or not supported/enabled in this service\"\n    )\n\n\nclass Unknown(GrpcException):\n    \"\"\"Unknown error.\n\n    For example, this error may be returned when a Status value received from another\n    address space belongs to an error space that is not known in this address space.\n    Also errors raised by APIs that do not return enough error information may be\n    converted to this error.\n    \"\"\"\n\n    pass\n"
  },
  {
    "path": "src/grpc_interceptor/py.typed",
    "content": ""
  },
  {
    "path": "src/grpc_interceptor/server.py",
    "content": "\"\"\"Base class for server-side interceptors.\"\"\"\n\nimport abc\nfrom asyncio import iscoroutine\nfrom typing import Any, Callable, Tuple\n\nimport grpc\nfrom grpc import aio as grpc_aio  # Needed for grpcio pre-1.33.2\n\n\nclass ServerInterceptor(grpc.ServerInterceptor, metaclass=abc.ABCMeta):\n    \"\"\"Base class for server-side interceptors.\n\n    To implement an interceptor, subclass this class and override the intercept method.\n    \"\"\"\n\n    @abc.abstractmethod\n    def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n    ) -> Any:  # pragma: no cover\n        \"\"\"Override this method to implement a custom interceptor.\n\n        You should call method(request_or_iterator, context) to invoke the next handler\n        (either the RPC method implementation, or the next interceptor in the list).\n\n        Args:\n            method: Either the RPC method implementation, or the next interceptor in\n                the chain.\n            request_or_iterator: The RPC request, as a protobuf message if it is a\n                unary request, or an iterator of protobuf messages if it is a streaming\n                request.\n            context: The ServicerContext passed by gRPC to the service.\n            method_name: A string of the form \"/protobuf.package.Service/Method\"\n\n        Returns:\n            This should return the result of method(request, context), which\n            is typically the RPC method response, as a protobuf message, or an\n            iterator of protobuf messages for streaming responses. 
The interceptor is\n            free to modify this in some way, however.\n        \"\"\"\n        return method(request_or_iterator, context)\n\n    # Implementation of grpc.ServerInterceptor, do not override.\n    def intercept_service(self, continuation, handler_call_details):\n        \"\"\"Implementation of grpc.ServerInterceptor.\n\n        This is not part of the grpc_interceptor.ServerInterceptor API, but must have\n        a public name. Do not override it, unless you know what you're doing.\n        \"\"\"\n        next_handler = continuation(handler_call_details)\n        # Returns None if the method isn't implemented.\n        if next_handler is None:\n            return\n\n        handler_factory, next_handler_method = _get_factory_and_method(next_handler)\n\n        def invoke_intercept_method(request_or_iterator, context):\n            method_name = handler_call_details.method\n            return self.intercept(\n                next_handler_method,\n                request_or_iterator,\n                context,\n                method_name,\n            )\n\n        return handler_factory(\n            invoke_intercept_method,\n            request_deserializer=next_handler.request_deserializer,\n            response_serializer=next_handler.response_serializer,\n        )\n\n\nclass AsyncServerInterceptor(grpc_aio.ServerInterceptor, metaclass=abc.ABCMeta):\n    \"\"\"Base class for asyncio server-side interceptors.\n\n    To implement an interceptor, subclass this class and override the intercept method.\n    \"\"\"\n\n    @abc.abstractmethod\n    async def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc_aio.ServicerContext,\n        method_name: str,\n    ) -> Any:  # pragma: no cover\n        \"\"\"Override this method to implement a custom interceptor.\n\n        You should await method(request_or_iterator, context) to invoke the next handler\n        (either the RPC method 
implementation, or the next interceptor in the list).\n\n        Args:\n            method: Either the RPC method implementation, or the next interceptor in\n                the chain.\n            request_or_iterator: The RPC request, as a protobuf message if it is a\n                unary request, or an iterator of protobuf messages if it is a streaming\n                request.\n            context: The ServicerContext passed by gRPC to the service.\n            method_name: A string of the form \"/protobuf.package.Service/Method\"\n\n        Returns:\n            This should return the result of method(request_or_iterator, context),\n            which is typically the RPC method response, as a protobuf message. The\n            interceptor is free to modify this in some way, however.\n        \"\"\"\n        response_or_iterator = method(request_or_iterator, context)\n        if hasattr(response_or_iterator, \"__aiter__\"):\n            return response_or_iterator\n        else:\n            return await response_or_iterator\n\n    # Implementation of grpc.aio.ServerInterceptor, do not override.\n    async def intercept_service(self, continuation, handler_call_details):\n        \"\"\"Implementation of grpc.aio.ServerInterceptor.\n\n        This is not part of the grpc_interceptor.AsyncServerInterceptor API, but must\n        have a public name. 
Do not override it, unless you know what you're doing.\n        \"\"\"\n        next_handler = await continuation(handler_call_details)\n        # Returns None if the method isn't implemented.\n        if not next_handler:\n            return\n\n        handler_factory, next_handler_method = _get_factory_and_method(next_handler)\n\n        if next_handler.response_streaming:\n\n            async def invoke_intercept_method(request, context):\n                method_name = handler_call_details.method\n                coroutine_or_asyncgen = self.intercept(\n                    next_handler_method,\n                    request,\n                    context,\n                    method_name,\n                )\n\n                # Async server streaming handlers return async_generator, because they\n                # use the async def + yield syntax. However, this is NOT a coroutine\n                # and hence is not awaitable. This can be a problem if the interceptor\n                # ignores the individual streaming response items and simply returns the\n                # result of method(request, context). In that case the interceptor IS a\n                # coroutine, and hence should be awaited. 
In both cases, we need\n                # something we can iterate over so that THIS function is an\n                # async_generator like the actual RPC method.\n                if iscoroutine(coroutine_or_asyncgen):\n                    asyncgen_or_none = await coroutine_or_asyncgen\n                    # If a handler is using the read/write API, it will return None.\n                    if not asyncgen_or_none:\n                        return\n                    asyncgen = asyncgen_or_none\n                else:\n                    asyncgen = coroutine_or_asyncgen\n\n                async for r in asyncgen:\n                    yield r\n\n        else:\n\n            async def invoke_intercept_method(request, context):\n                method_name = handler_call_details.method\n                return await self.intercept(\n                    next_handler_method,\n                    request,\n                    context,\n                    method_name,\n                )\n\n        return handler_factory(\n            invoke_intercept_method,\n            request_deserializer=next_handler.request_deserializer,\n            response_serializer=next_handler.response_serializer,\n        )\n\n\ndef _get_factory_and_method(\n    rpc_handler: grpc.RpcMethodHandler,\n) -> Tuple[Callable, Callable]:\n    if rpc_handler.unary_unary:\n        return grpc.unary_unary_rpc_method_handler, rpc_handler.unary_unary\n    elif rpc_handler.unary_stream:\n        return grpc.unary_stream_rpc_method_handler, rpc_handler.unary_stream\n    elif rpc_handler.stream_unary:\n        return grpc.stream_unary_rpc_method_handler, rpc_handler.stream_unary\n    elif rpc_handler.stream_stream:\n        return grpc.stream_stream_rpc_method_handler, rpc_handler.stream_stream\n    else:  # pragma: no cover\n        raise RuntimeError(\"RPC handler implementation does not exist\")\n\n\nclass MethodName:\n    \"\"\"Represents a gRPC method name.\n\n    gRPC methods are defined by three parts, 
represented by the three attributes.\n\n    Attributes:\n        package: This is defined by the `package foo.bar;` designation in the protocol\n            buffer definition, or it could be defined by the protocol buffer directory\n            structure, depending on the language\n            (see https://developers.google.com/protocol-buffers/docs/proto3#packages).\n        service: This is the service name in the protocol buffer definition (e.g.,\n            `service SearchService { ... }`).\n        method: This is the method name (e.g., `rpc Search(...) returns (...);`).\n    \"\"\"\n\n    def __init__(self, package: str, service: str, method: str):\n        self.package = package\n        self.service = service\n        self.method = method\n\n    def __repr__(self) -> str:\n        \"\"\"Object-like representation.\"\"\"\n        return (\n            f\"MethodName(package='{self.package}', service='{self.service}',\"\n            f\" method='{self.method}')\"\n        )\n\n    @property\n    def fully_qualified_service(self):\n        \"\"\"Return the service name prefixed with the package.\n\n        Example:\n            >>> MethodName(\"foo.bar\", \"SearchService\", \"Search\").fully_qualified_service\n            'foo.bar.SearchService'\n        \"\"\"\n        return f\"{self.package}.{self.service}\" if self.package else self.service\n\n\ndef parse_method_name(method_name: str) -> MethodName:\n    \"\"\"Parse a method name into package, service and method components.\n\n    Arguments:\n        method_name: A string of the form \"/foo.bar.SearchService/Search\", as passed to\n            ServerInterceptor.intercept().\n\n    Returns:\n        A MethodName object.\n\n    Example:\n        >>> parse_method_name(\"/foo.bar.SearchService/Search\")\n        MethodName(package='foo.bar', service='SearchService', method='Search')\n    \"\"\"\n    _, package_and_service, method = method_name.split(\"/\")\n    *maybe_package, service = 
package_and_service.rsplit(\".\", maxsplit=1)\n    package = maybe_package[0] if maybe_package else \"\"\n    return MethodName(package, service, method)\n"
  },
  {
    "path": "src/grpc_interceptor/testing/__init__.py",
    "content": "\"\"\"A framework for testing interceptors.\"\"\"\n\nfrom typing import Callable\n\nfrom grpc_interceptor.testing.dummy_client import (\n    dummy_client,\n    DummyService,\n)\nfrom grpc_interceptor.testing.protos.dummy_pb2 import DummyRequest, DummyResponse\n\n\n__all__ = [\n    \"dummy_client\",\n    \"DummyRequest\",\n    \"DummyResponse\",\n    \"DummyService\",\n    \"raises\",\n]\n\n\ndef raises(e: Exception) -> Callable:\n    \"\"\"Return a function that raises the given exception when called.\n\n    Args:\n        e: The exception to be raised.\n\n    Returns:\n        A function that can take any arguments, and raises the given exception.\n    \"\"\"\n\n    def f(*args, **kwargs):\n        raise (e)\n\n    return f\n"
  },
  {
    "path": "src/grpc_interceptor/testing/dummy_client.py",
    "content": "\"\"\"Defines a service and client for testing interceptors.\"\"\"\n\nimport asyncio\nfrom concurrent import futures\nfrom contextlib import contextmanager\nfrom inspect import iscoroutine\nfrom threading import Event, Thread\nfrom typing import (\n    Any,\n    AsyncGenerator,\n    AsyncIterable,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Optional,\n    Union\n)\n\nimport grpc\n\nfrom grpc_interceptor.client import ClientInterceptor\nfrom grpc_interceptor.server import AsyncServerInterceptor, grpc_aio, ServerInterceptor\nfrom grpc_interceptor.testing.protos import dummy_pb2_grpc\nfrom grpc_interceptor.testing.protos.dummy_pb2 import DummyRequest, DummyResponse\n\nSpecialCaseFunction = Callable[\n    [str, Union[grpc.ServicerContext, grpc_aio.ServicerContext]], str\n]\n\n\nclass _SpecialCaseMixin:\n    _special_cases: Dict[str, SpecialCaseFunction]\n\n    def _get_output(self, request: DummyRequest, context: grpc.ServicerContext) -> str:\n        input = request.input\n\n        output = input\n        if input in self._special_cases:\n            output = self._special_cases[input](input, context)\n\n        return output\n\n    async def _get_output_async(\n        self,\n        request: DummyRequest,\n        context: grpc_aio.ServicerContext\n    ) -> str:\n        input = request.input\n\n        output = input\n        if input in self._special_cases:\n            output = self._special_cases[input](input, context)\n            if iscoroutine(output):\n                output = await output\n\n        return output\n\n\nclass DummyService(dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin):\n    \"\"\"A gRPC service used for testing.\n\n    Args:\n        special_cases: A dictionary where the keys are strings, and the values are\n            functions that take and return strings. The functions can also raise\n            exceptions. 
When the Execute method is given a string in the dict, it\n            will call the function with that string instead, and return the result.\n            This allows testing special cases, like raising exceptions.\n    \"\"\"\n\n    def __init__(\n        self,\n        special_cases: Dict[str, SpecialCaseFunction],\n    ):\n        self._special_cases = special_cases\n\n    def Execute(\n        self, request: DummyRequest, context: grpc.ServicerContext\n    ) -> DummyResponse:\n        \"\"\"Echo the input, or take on of the special cases actions.\"\"\"\n        return DummyResponse(output=self._get_output(request, context))\n\n    def ExecuteClientStream(\n        self, request_iter: Iterable[DummyRequest], context: grpc.ServicerContext\n    ) -> DummyResponse:\n        \"\"\"Iterate over the input and concatenates the strings into the output.\"\"\"\n        output = \"\".join(self._get_output(request, context) for request in request_iter)\n        return DummyResponse(output=output)\n\n    def ExecuteServerStream(\n        self, request: DummyRequest, context: grpc.ServicerContext\n    ) -> Iterable[DummyResponse]:\n        \"\"\"Stream one character at a time from the input.\"\"\"\n        for c in self._get_output(request, context):\n            yield DummyResponse(output=c)\n\n    def ExecuteClientServerStream(\n        self, request_iter: Iterable[DummyRequest], context: grpc.ServicerContext\n    ) -> Iterable[DummyResponse]:\n        \"\"\"Stream input to output.\"\"\"\n        for request in request_iter:\n            yield DummyResponse(output=self._get_output(request, context))\n\n\nclass AsyncDummyService(dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin):\n    \"\"\"A gRPC service used for testing, similar to DummyService except async.\n\n    See DummyService for more info.\n    \"\"\"\n\n    def __init__(\n        self,\n        special_cases: Dict[str, SpecialCaseFunction],\n    ):\n        self._special_cases = special_cases\n\n    async def 
Execute(\n        self, request: DummyRequest, context: grpc_aio.ServicerContext\n    ) -> DummyResponse:\n        \"\"\"Echo the input, or take one of the special cases actions.\"\"\"\n        return DummyResponse(output=await self._get_output_async(request, context))\n\n    async def ExecuteClientStream(\n        self,\n        request_iter: AsyncIterable[DummyRequest],\n        context: grpc_aio.ServicerContext,\n    ) -> DummyResponse:\n        \"\"\"Iterate over the input and concatenate the strings into the output.\"\"\"\n        output = \"\".join([\n            await self._get_output_async(request, context)\n            async for request in request_iter\n        ])  # noqa: E501\n        return DummyResponse(output=output)\n\n    async def ExecuteServerStream(\n        self, request: DummyRequest, context: grpc_aio.ServicerContext\n    ) -> AsyncGenerator[DummyResponse, None]:\n        \"\"\"Stream one character at a time from the input.\"\"\"\n        for c in await self._get_output_async(request, context):\n            yield DummyResponse(output=c)\n\n    async def ExecuteClientServerStream(\n        self,\n        request_iter: AsyncIterable[DummyRequest],\n        context: grpc_aio.ServicerContext,\n    ) -> AsyncGenerator[DummyResponse, None]:\n        \"\"\"Stream input to output.\"\"\"\n        async for request in request_iter:\n            yield DummyResponse(output=await self._get_output_async(request, context))\n\n\nclass AsyncReadWriteDummyService(\n    dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin\n):\n    \"\"\"Similar to AsyncDummyService except uses the read / write API.\n\n    See DummyService for more info.\n    \"\"\"\n\n    def __init__(\n        self,\n        special_cases: Dict[str, SpecialCaseFunction],\n    ):\n        self._special_cases = special_cases\n\n    async def Execute(\n        self, request: DummyRequest, context: grpc_aio.ServicerContext\n    ) -> DummyResponse:\n        \"\"\"Echo the input, or take one of the 
special cases actions.\"\"\"\n        return DummyResponse(output=await self._get_output_async(request, context))\n\n    async def ExecuteClientStream(\n        self,\n        unused_request: Any,\n        context: grpc_aio.ServicerContext,\n    ) -> DummyResponse:\n        \"\"\"Iterate over the input and concatenates the strings into the output.\"\"\"\n        output = []\n        while True:\n            request = await context.read()\n            if request == grpc_aio.EOF:\n                break\n            output.append(await self._get_output_async(request, context))\n\n        return DummyResponse(output=\"\".join(output))\n\n    async def ExecuteServerStream(\n        self, request: DummyRequest, context: grpc_aio.ServicerContext\n    ) -> None:\n        \"\"\"Stream one character at a time from the input.\"\"\"\n        for c in await self._get_output_async(request, context):\n            await context.write(DummyResponse(output=c))\n\n    async def ExecuteClientServerStream(\n        self,\n        request_iter: AsyncIterable[DummyRequest],\n        context: grpc_aio.ServicerContext,\n    ) -> None:\n        \"\"\"Stream input to output.\"\"\"\n        while True:\n            request = await context.read()\n            if request == grpc_aio.EOF:\n                break\n            await context.write(\n                DummyResponse(output=await self._get_output_async(request, context))\n            )\n\n\n@contextmanager\ndef dummy_client(\n    special_cases: Dict[str, SpecialCaseFunction],\n    interceptors: Optional[List[ServerInterceptor]] = None,\n    client_interceptors: Optional[List[ClientInterceptor]] = None,\n    aio_server: bool = False,\n    aio_client: bool = False,\n    aio_read_write: bool = False,\n):\n    \"\"\"A context manager that returns a gRPC client connected to a DummyService.\"\"\"\n    # Sanity check that the interceptors are async if using an async server,\n    # otherwise the tests will just hang.\n    for intr in 
interceptors or []:\n        if aio_server != isinstance(intr, AsyncServerInterceptor):\n            raise TypeError(\"Set aio_server correctly\")\n    with dummy_channel(\n        special_cases,\n        interceptors,\n        client_interceptors,\n        aio_server=aio_server,\n        aio_client=aio_client,\n        aio_read_write=aio_read_write,\n    ) as channel:\n        client = dummy_pb2_grpc.DummyServiceStub(channel)\n        yield client\n\n\n@contextmanager\ndef dummy_channel(\n    special_cases: Dict[str, SpecialCaseFunction],\n    interceptors: Optional[List[ServerInterceptor]] = None,\n    client_interceptors: Optional[List[ClientInterceptor]] = None,\n    aio_server: bool = False,\n    aio_client: bool = False,\n    aio_read_write: bool = False,\n):\n    \"\"\"A context manager that returns a gRPC channel connected to a DummyService.\"\"\"\n    if not interceptors:\n        interceptors = []\n\n    if aio_server:\n        service = (\n            AsyncReadWriteDummyService(special_cases)\n            if aio_read_write\n            else AsyncDummyService(special_cases)\n        )\n        aio_loop = asyncio.new_event_loop()\n        aio_thread = _AsyncServerThread(\n            aio_loop,\n            service,\n            interceptors,\n        )\n        aio_thread.start()\n        aio_thread.wait_for_server()\n        port = aio_thread.port\n    else:\n        dummy_service = DummyService(special_cases)\n        server = grpc.server(\n            futures.ThreadPoolExecutor(max_workers=1), interceptors=interceptors\n        )\n        dummy_pb2_grpc.add_DummyServiceServicer_to_server(dummy_service, server)\n        port = server.add_insecure_port(\"localhost:0\")\n        server.start()\n\n    channel_descriptor = f\"localhost:{port}\"\n\n    if aio_client:\n        channel = grpc_aio.insecure_channel(channel_descriptor)\n        # Client interceptors might work, but I haven't tested them yet.\n        if client_interceptors:\n            raise 
TypeError(\"Client interceptors not supported with async channel\")\n        # We close the channel in _AsyncServerThread because we need to await\n        # it, and doing that in this thread is problematic because dummy_client\n        # isn't always used in an async context. We could get around that by\n        # creating a new loop or something, but will be lazy and use the server\n        # thread / loop for now.\n        if not aio_server:\n            raise ValueError(\"aio_server must be True if aio_client is True\")\n        aio_thread.async_channel = channel\n    else:\n        channel = grpc.insecure_channel(channel_descriptor)\n        if client_interceptors:\n            channel = grpc.intercept_channel(channel, *client_interceptors)\n\n    try:\n        yield channel\n    finally:\n        if not aio_client:\n            # async channel is closed by _AsyncServerThread\n            channel.close()\n\n        if aio_server:\n            aio_thread.stop()\n            aio_thread.join()\n        else:\n            server.stop(None)\n\n\nclass _AsyncServerThread(Thread):\n    port: int = 0\n    async_channel = None\n\n    def __init__(\n        self,\n        loop: asyncio.AbstractEventLoop,\n        service,\n        interceptors: List[AsyncServerInterceptor],\n    ):\n        super().__init__()\n        self.__loop = loop\n        self.__service = service\n        self.__interceptors = interceptors\n        self.__started = Event()\n\n    def run(self):\n        asyncio.set_event_loop(self.__loop)\n        self.__loop.run_until_complete(self.__run_server())\n\n    async def __run_server(self):\n        self.__server = grpc_aio.server(interceptors=tuple(self.__interceptors))\n        dummy_pb2_grpc.add_DummyServiceServicer_to_server(self.__service, self.__server)\n        self.port = self.__server.add_insecure_port(\"localhost:0\")\n        await self.__server.start()\n        self.__started.set()\n        await self.__server.wait_for_termination()\n       
 if self.async_channel:\n            await self.async_channel.close()\n\n    def wait_for_server(self):\n        self.__started.wait()\n\n    def stop(self):\n        self.__loop.call_soon_threadsafe(\n            lambda: asyncio.ensure_future(self.__shutdown())\n        )\n\n    async def __shutdown(self) -> None:\n        await self.__server.stop(None)\n        self.__loop.stop()\n"
  },
  {
    "path": "src/grpc_interceptor/testing/protos/__init__.py",
    "content": "\"\"\"Protobuf definitions for testing.\"\"\"\n"
  },
  {
    "path": "src/grpc_interceptor/testing/protos/dummy.proto",
    "content": "syntax = \"proto3\";\n\nmessage DummyRequest {\n    string input = 1;\n}\n\nmessage DummyResponse {\n    string output = 1;\n}\n\nservice DummyService {\n    rpc Execute (DummyRequest) returns (DummyResponse);\n    rpc ExecuteClientStream (stream DummyRequest) returns (DummyResponse);\n    rpc ExecuteServerStream (DummyRequest) returns (stream DummyResponse);\n    rpc ExecuteClientServerStream (stream DummyRequest) returns (stream DummyResponse);\n}\n"
  },
  {
    "path": "src/grpc_interceptor/testing/protos/dummy_pb2.py",
    "content": "# -*- coding: utf-8 -*-\n# Generated by the protocol buffer compiler.  DO NOT EDIT!\n# source: dummy.proto\n\"\"\"Generated protocol buffer code.\"\"\"\nfrom google.protobuf import descriptor as _descriptor\nfrom google.protobuf import descriptor_pool as _descriptor_pool\nfrom google.protobuf import message as _message\nfrom google.protobuf import reflection as _reflection\nfrom google.protobuf import symbol_database as _symbol_database\n\n# @@protoc_insertion_point(imports)\n\n_sym_db = _symbol_database.Default()\n\n\nDESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(\n    b'\\n\\x0b\\x64ummy.proto\"\\x1d\\n\\x0c\\x44ummyRequest\\x12\\r\\n\\x05input\\x18\\x01 \\x01(\\t\"\\x1f\\n\\rDummyResponse\\x12\\x0e\\n\\x06output\\x18\\x01 \\x01(\\t2\\xe8\\x01\\n\\x0c\\x44ummyService\\x12(\\n\\x07\\x45xecute\\x12\\r.DummyRequest\\x1a\\x0e.DummyResponse\\x12\\x36\\n\\x13\\x45xecuteClientStream\\x12\\r.DummyRequest\\x1a\\x0e.DummyResponse(\\x01\\x12\\x36\\n\\x13\\x45xecuteServerStream\\x12\\r.DummyRequest\\x1a\\x0e.DummyResponse0\\x01\\x12>\\n\\x19\\x45xecuteClientServerStream\\x12\\r.DummyRequest\\x1a\\x0e.DummyResponse(\\x01\\x30\\x01\\x62\\x06proto3'\n)\n\n\n_DUMMYREQUEST = DESCRIPTOR.message_types_by_name[\"DummyRequest\"]\n_DUMMYRESPONSE = DESCRIPTOR.message_types_by_name[\"DummyResponse\"]\nDummyRequest = _reflection.GeneratedProtocolMessageType(\n    \"DummyRequest\",\n    (_message.Message,),\n    {\n        \"DESCRIPTOR\": _DUMMYREQUEST,\n        \"__module__\": \"dummy_pb2\"\n        # @@protoc_insertion_point(class_scope:DummyRequest)\n    },\n)\n_sym_db.RegisterMessage(DummyRequest)\n\nDummyResponse = _reflection.GeneratedProtocolMessageType(\n    \"DummyResponse\",\n    (_message.Message,),\n    {\n        \"DESCRIPTOR\": _DUMMYRESPONSE,\n        \"__module__\": \"dummy_pb2\"\n        # @@protoc_insertion_point(class_scope:DummyResponse)\n    },\n)\n_sym_db.RegisterMessage(DummyResponse)\n\n_DUMMYSERVICE = 
DESCRIPTOR.services_by_name[\"DummyService\"]\nif _descriptor._USE_C_DESCRIPTORS == False:\n    DESCRIPTOR._options = None\n    _DUMMYREQUEST._serialized_start = 15\n    _DUMMYREQUEST._serialized_end = 44\n    _DUMMYRESPONSE._serialized_start = 46\n    _DUMMYRESPONSE._serialized_end = 77\n    _DUMMYSERVICE._serialized_start = 80\n    _DUMMYSERVICE._serialized_end = 312\n# @@protoc_insertion_point(module_scope)\n"
  },
  {
    "path": "src/grpc_interceptor/testing/protos/dummy_pb2.pyi",
    "content": "# @generated by generate_proto_mypy_stubs.py.  Do not edit!\nimport sys\nfrom google.protobuf.descriptor import (\n    Descriptor as google___protobuf___descriptor___Descriptor,\n    FileDescriptor as google___protobuf___descriptor___FileDescriptor,\n)\n\nfrom google.protobuf.message import Message as google___protobuf___message___Message\n\nfrom typing import (\n    Optional as typing___Optional,\n    Text as typing___Text,\n)\n\nfrom typing_extensions import Literal as typing_extensions___Literal\n\nbuiltin___bool = bool\nbuiltin___bytes = bytes\nbuiltin___float = float\nbuiltin___int = int\n\nDESCRIPTOR: google___protobuf___descriptor___FileDescriptor = ...\n\nclass DummyRequest(google___protobuf___message___Message):\n    DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...\n    input: typing___Text = ...\n    def __init__(\n        self,\n        *,\n        input: typing___Optional[typing___Text] = None,\n    ) -> None: ...\n    def ClearField(\n        self, field_name: typing_extensions___Literal[\"input\", b\"input\"]\n    ) -> None: ...\n\ntype___DummyRequest = DummyRequest\n\nclass DummyResponse(google___protobuf___message___Message):\n    DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...\n    output: typing___Text = ...\n    def __init__(\n        self,\n        *,\n        output: typing___Optional[typing___Text] = None,\n    ) -> None: ...\n    def ClearField(\n        self, field_name: typing_extensions___Literal[\"output\", b\"output\"]\n    ) -> None: ...\n\ntype___DummyResponse = DummyResponse\n"
  },
  {
    "path": "src/grpc_interceptor/testing/protos/dummy_pb2_grpc.py",
    "content": "# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!\n\"\"\"Client and server classes corresponding to protobuf-defined services.\"\"\"\nimport grpc\n\nfrom grpc_interceptor.testing.protos import (\n    dummy_pb2 as grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2,\n)\n\n\nclass DummyServiceStub(object):\n    \"\"\"Missing associated documentation comment in .proto file.\"\"\"\n\n    def __init__(self, channel):\n        \"\"\"Constructor.\n\n        Args:\n            channel: A grpc.Channel.\n        \"\"\"\n        self.Execute = channel.unary_unary(\n            \"/DummyService/Execute\",\n            request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n        )\n        self.ExecuteClientStream = channel.stream_unary(\n            \"/DummyService/ExecuteClientStream\",\n            request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n        )\n        self.ExecuteServerStream = channel.unary_stream(\n            \"/DummyService/ExecuteServerStream\",\n            request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n        )\n        self.ExecuteClientServerStream = channel.stream_stream(\n            \"/DummyService/ExecuteClientServerStream\",\n            request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n        )\n\n\nclass 
DummyServiceServicer(object):\n    \"\"\"Missing associated documentation comment in .proto file.\"\"\"\n\n    def Execute(self, request, context):\n        \"\"\"Missing associated documentation comment in .proto file.\"\"\"\n        context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n        context.set_details(\"Method not implemented!\")\n        raise NotImplementedError(\"Method not implemented!\")\n\n    def ExecuteClientStream(self, request_iterator, context):\n        \"\"\"Missing associated documentation comment in .proto file.\"\"\"\n        context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n        context.set_details(\"Method not implemented!\")\n        raise NotImplementedError(\"Method not implemented!\")\n\n    def ExecuteServerStream(self, request, context):\n        \"\"\"Missing associated documentation comment in .proto file.\"\"\"\n        context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n        context.set_details(\"Method not implemented!\")\n        raise NotImplementedError(\"Method not implemented!\")\n\n    def ExecuteClientServerStream(self, request_iterator, context):\n        \"\"\"Missing associated documentation comment in .proto file.\"\"\"\n        context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n        context.set_details(\"Method not implemented!\")\n        raise NotImplementedError(\"Method not implemented!\")\n\n\ndef add_DummyServiceServicer_to_server(servicer, server):\n    rpc_method_handlers = {\n        \"Execute\": grpc.unary_unary_rpc_method_handler(\n            servicer.Execute,\n            request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,\n            response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,\n        ),\n        \"ExecuteClientStream\": grpc.stream_unary_rpc_method_handler(\n            servicer.ExecuteClientStream,\n            
request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,\n            response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,\n        ),\n        \"ExecuteServerStream\": grpc.unary_stream_rpc_method_handler(\n            servicer.ExecuteServerStream,\n            request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,\n            response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,\n        ),\n        \"ExecuteClientServerStream\": grpc.stream_stream_rpc_method_handler(\n            servicer.ExecuteClientServerStream,\n            request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,\n            response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,\n        ),\n    }\n    generic_handler = grpc.method_handlers_generic_handler(\n        \"DummyService\", rpc_method_handlers\n    )\n    server.add_generic_rpc_handlers((generic_handler,))\n\n\n# This class is part of an EXPERIMENTAL API.\nclass DummyService(object):\n    \"\"\"Missing associated documentation comment in .proto file.\"\"\"\n\n    @staticmethod\n    def Execute(\n        request,\n        target,\n        options=(),\n        channel_credentials=None,\n        call_credentials=None,\n        insecure=False,\n        compression=None,\n        wait_for_ready=None,\n        timeout=None,\n        metadata=None,\n    ):\n        return grpc.experimental.unary_unary(\n            request,\n            target,\n            \"/DummyService/Execute\",\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n            options,\n            channel_credentials,\n         
   insecure,\n            call_credentials,\n            compression,\n            wait_for_ready,\n            timeout,\n            metadata,\n        )\n\n    @staticmethod\n    def ExecuteClientStream(\n        request_iterator,\n        target,\n        options=(),\n        channel_credentials=None,\n        call_credentials=None,\n        insecure=False,\n        compression=None,\n        wait_for_ready=None,\n        timeout=None,\n        metadata=None,\n    ):\n        return grpc.experimental.stream_unary(\n            request_iterator,\n            target,\n            \"/DummyService/ExecuteClientStream\",\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n            options,\n            channel_credentials,\n            insecure,\n            call_credentials,\n            compression,\n            wait_for_ready,\n            timeout,\n            metadata,\n        )\n\n    @staticmethod\n    def ExecuteServerStream(\n        request,\n        target,\n        options=(),\n        channel_credentials=None,\n        call_credentials=None,\n        insecure=False,\n        compression=None,\n        wait_for_ready=None,\n        timeout=None,\n        metadata=None,\n    ):\n        return grpc.experimental.unary_stream(\n            request,\n            target,\n            \"/DummyService/ExecuteServerStream\",\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n            options,\n            channel_credentials,\n            insecure,\n            call_credentials,\n            compression,\n            wait_for_ready,\n            timeout,\n            metadata,\n        )\n\n    @staticmethod\n    def ExecuteClientServerStream(\n        
request_iterator,\n        target,\n        options=(),\n        channel_credentials=None,\n        call_credentials=None,\n        insecure=False,\n        compression=None,\n        wait_for_ready=None,\n        timeout=None,\n        metadata=None,\n    ):\n        return grpc.experimental.stream_stream(\n            request_iterator,\n            target,\n            \"/DummyService/ExecuteClientServerStream\",\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,\n            grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,\n            options,\n            channel_credentials,\n            insecure,\n            call_credentials,\n            compression,\n            wait_for_ready,\n            timeout,\n            metadata,\n        )\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "\"\"\"Test suite for grpc-interceptor.\"\"\"\n"
  },
  {
    "path": "tests/test_client.py",
    "content": "\"\"\"Test cases for the grpc-interceptor base ClientInterceptor.\"\"\"\n\nfrom collections import defaultdict\nimport itertools\nfrom typing import List, Tuple\n\nimport grpc\nimport pytest\n\nfrom grpc_interceptor import ClientInterceptor\nfrom grpc_interceptor.testing import dummy_client, DummyRequest, raises\n\n\nclass MetadataInterceptor(ClientInterceptor):\n    \"\"\"A test interceptor that injects invocation metadata.\"\"\"\n\n    def __init__(self, metadata: List[Tuple[str, str]]):\n        self._metadata = metadata\n\n    def intercept(self, method, request_or_iterator, call_details):\n        \"\"\"Add invocation metadata to request.\"\"\"\n        new_details = call_details._replace(metadata=self._metadata)\n        return method(request_or_iterator, new_details)\n\n\nclass CodeCountInterceptor(ClientInterceptor):\n    \"\"\"Test interceptor that counts status codes returned by the server.\"\"\"\n\n    def __init__(self):\n        self.counts = defaultdict(int)\n\n    def intercept(self, method, request_or_iterator, call_details):\n        \"\"\"Call continuation and count status codes.\"\"\"\n        future = method(request_or_iterator, call_details)\n        self.counts[future.code()] += 1\n        return future\n\n\nclass RetryInterceptor(ClientInterceptor):\n    \"\"\"Test interceptor that retries failed RPCs.\"\"\"\n\n    def __init__(self, retries):\n        self._retries = retries\n\n    def intercept(self, method, request_or_iterator, call_details):\n        \"\"\"Call the continuation and retry up to retries times if it fails.\"\"\"\n        tries_remaining = 1 + self._retries\n        while 0 < tries_remaining:\n            future = method(request_or_iterator, call_details)\n            try:\n                future.result()\n                return future\n            except Exception:\n                tries_remaining -= 1\n\n        return future\n\n\nclass CrashingService:\n    \"\"\"Special case function that raises a given 
number of times before succeeding.\"\"\"\n\n    DEFAULT_EXCEPTION = ValueError(\"oops\")\n\n    def __init__(self, num_crashes, success_value=\"OK\", exception=DEFAULT_EXCEPTION):\n        self._num_crashes = num_crashes\n        self._success_value = success_value\n        self._exception = exception\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"Raise the first num_crashes times called, then return success_value.\"\"\"\n        if 0 < self._num_crashes:\n            self._num_crashes -= 1\n            raise self._exception\n\n        return self._success_value\n\n\nclass CachingInterceptor(ClientInterceptor):\n    \"\"\"A test interceptor that caches responses based on input string.\"\"\"\n\n    def __init__(self):\n        self._cache = {}\n\n    def intercept(self, method, request_or_iterator, call_details):\n        \"\"\"Cache responses based on input string.\"\"\"\n        if hasattr(request_or_iterator, \"__iter__\"):\n            request_or_iterator, copy_iterator = itertools.tee(request_or_iterator)\n            cache_key = tuple(r.input for r in copy_iterator)\n        else:\n            cache_key = request_or_iterator.input\n\n        if cache_key not in self._cache:\n            self._cache[cache_key] = method(request_or_iterator, call_details)\n\n        return self._cache[cache_key]\n\n\n@pytest.fixture\ndef metadata_string():\n    \"\"\"Expected joined metadata string.\"\"\"\n    return \"this_key:this_value\"\n\n\n@pytest.fixture\ndef metadata_client():\n    \"\"\"Client with metadata interceptor.\"\"\"\n    intr = MetadataInterceptor([(\"this_key\", \"this_value\")])\n    interceptors = [intr]\n\n    special_cases = {\n        \"metadata\": lambda _, c: \",\".join(\n            f\"{key}:{value}\" for key, value in c.invocation_metadata()\n        )\n    }\n    with dummy_client(\n        special_cases=special_cases, client_interceptors=interceptors\n    ) as client:\n        yield client\n\n\ndef test_metadata_unary(metadata_client, 
metadata_string):\n    \"\"\"Invocation metadata should be added to the servicer context.\"\"\"\n    unary_output = metadata_client.Execute(DummyRequest(input=\"metadata\")).output\n    assert metadata_string in unary_output\n\n\ndef test_metadata_server_stream(metadata_client, metadata_string):\n    \"\"\"Invocation metadata should be added to the servicer context.\"\"\"\n    server_stream_output = [\n        r.output\n        for r in metadata_client.ExecuteServerStream(DummyRequest(input=\"metadata\"))\n    ]\n    assert metadata_string in \"\".join(server_stream_output)\n\n\ndef test_metadata_client_stream(metadata_client, metadata_string):\n    \"\"\"Invocation metadata should be added to the servicer context.\"\"\"\n    client_stream_input = iter((DummyRequest(input=\"metadata\"),))\n    client_stream_output = metadata_client.ExecuteClientStream(\n        client_stream_input\n    ).output\n    assert metadata_string in client_stream_output\n\n\ndef test_metadata_client_server_stream(metadata_client, metadata_string):\n    \"\"\"Invocation metadata should be added to the servicer context.\"\"\"\n    stream_stream_input = iter((DummyRequest(input=\"metadata\"),))\n    result = metadata_client.ExecuteClientServerStream(stream_stream_input)\n    stream_stream_output = [r.output for r in result]\n    assert metadata_string in \"\".join(stream_stream_output)\n\n\ndef test_code_counting():\n    \"\"\"Access to code on call details works correctly.\"\"\"\n    interceptor = CodeCountInterceptor()\n    special_cases = {\"error\": raises(ValueError(\"oops\"))}\n    with dummy_client(\n        special_cases=special_cases, client_interceptors=[interceptor]\n    ) as client:\n        assert interceptor.counts == {}\n        client.Execute(DummyRequest(input=\"foo\"))\n        assert interceptor.counts == {grpc.StatusCode.OK: 1}\n        with pytest.raises(grpc.RpcError):\n            client.Execute(DummyRequest(input=\"error\"))\n        assert interceptor.counts == 
{grpc.StatusCode.OK: 1, grpc.StatusCode.UNKNOWN: 1}\n\n\ndef test_basic_retry():\n    \"\"\"Calling the continuation multiple times should work.\"\"\"\n    interceptor = RetryInterceptor(retries=1)\n    special_cases = {\"error_once\": CrashingService(num_crashes=1)}\n    with dummy_client(\n        special_cases=special_cases, client_interceptors=[interceptor]\n    ) as client:\n        assert client.Execute(DummyRequest(input=\"error_once\")).output == \"OK\"\n\n\ndef test_failed_retry():\n    \"\"\"The interceptor can return failed futures.\"\"\"\n    interceptor = RetryInterceptor(retries=1)\n    special_cases = {\"error_twice\": CrashingService(num_crashes=2)}\n    with dummy_client(\n        special_cases=special_cases, client_interceptors=[interceptor]\n    ) as client:\n        with pytest.raises(grpc.RpcError):\n            client.Execute(DummyRequest(input=\"error_twice\"))\n\n\ndef test_chaining():\n    \"\"\"Chaining interceptors should work.\"\"\"\n    retry_interceptor = RetryInterceptor(retries=1)\n    code_count_interceptor = CodeCountInterceptor()\n    interceptors = [retry_interceptor, code_count_interceptor]\n    special_cases = {\"error_once\": CrashingService(num_crashes=1)}\n    with dummy_client(\n        special_cases=special_cases, client_interceptors=interceptors\n    ) as client:\n        assert code_count_interceptor.counts == {}\n        assert client.Execute(DummyRequest(input=\"error_once\")).output == \"OK\"\n        assert code_count_interceptor.counts == {\n            grpc.StatusCode.OK: 1,\n            grpc.StatusCode.UNKNOWN: 1,\n        }\n\n\ndef test_caching():\n    \"\"\"Caching calls (not calling the continuation) should work.\"\"\"\n    caching_interceptor = CachingInterceptor()\n    # Use this to test how many times the continuation is called.\n    code_count_interceptor = CodeCountInterceptor()\n    interceptors = [caching_interceptor, code_count_interceptor]\n    with dummy_client(special_cases={}, 
client_interceptors=interceptors) as client:\n        assert code_count_interceptor.counts == {}\n        assert client.Execute(DummyRequest(input=\"hello\")).output == \"hello\"\n        assert code_count_interceptor.counts == {grpc.StatusCode.OK: 1}\n        assert client.Execute(DummyRequest(input=\"hello\")).output == \"hello\"\n        assert code_count_interceptor.counts == {grpc.StatusCode.OK: 1}\n        assert client.Execute(DummyRequest(input=\"goodbye\")).output == \"goodbye\"\n        assert code_count_interceptor.counts == {grpc.StatusCode.OK: 2}\n        # Try streaming requests\n        inputs = [\"foo\", \"bar\"]\n        input_iter = (DummyRequest(input=input) for input in inputs)\n        assert client.ExecuteClientStream(input_iter).output == \"foobar\"\n        assert code_count_interceptor.counts == {grpc.StatusCode.OK: 3}\n        input_iter = (DummyRequest(input=input) for input in inputs)\n        assert client.ExecuteClientStream(input_iter).output == \"foobar\"\n        assert code_count_interceptor.counts == {grpc.StatusCode.OK: 3}\n"
  },
  {
    "path": "tests/test_exception_to_status.py",
    "content": "\"\"\"Test cases for ExceptionToStatusInterceptor.\"\"\"\nimport re\nfrom typing import Any, List, Optional, Union\n\nimport grpc\nfrom grpc import aio as grpc_aio\nimport pytest\n\nfrom grpc_interceptor import exceptions as gx\nfrom grpc_interceptor.exception_to_status import (\n    AsyncExceptionToStatusInterceptor,\n    ExceptionToStatusInterceptor,\n)\nfrom grpc_interceptor.testing import dummy_client, DummyRequest, raises\n\n\nclass NonGrpcException(Exception):\n    \"\"\"An exception that does not derive from GrpcException.\"\"\"\n\n    TEST_STATUS_CODE = grpc.StatusCode.DATA_LOSS\n    TEST_DETAILS = \"Here's some custom details\"\n\n\nclass ExtendedExceptionToStatusInterceptor(ExceptionToStatusInterceptor):\n    \"\"\"A test case for extending ExceptionToStatusInterceptor.\"\"\"\n\n    def __init__(self):\n        self.caught_custom_exception = False\n\n    def handle_exception(self, ex, request_or_iterator, context, method_name):\n        \"\"\"Handles NonGrpcException in a special way.\"\"\"\n        if isinstance(ex, NonGrpcException):\n            self.caught_custom_exception = True\n            context.abort(\n                NonGrpcException.TEST_STATUS_CODE, NonGrpcException.TEST_DETAILS\n            )\n        else:\n            super().handle_exception(ex, request_or_iterator, context, method_name)\n\n\nclass AsyncExtendedExceptionToStatusInterceptor(AsyncExceptionToStatusInterceptor):\n    \"\"\"A test case for extending AsyncExceptionToStatusInterceptor.\"\"\"\n\n    def __init__(self):\n        self.caught_custom_exception = False\n\n    async def handle_exception(self, ex, request_or_iterator, context, method_name):\n        \"\"\"Handles NonGrpcException in a special way.\"\"\"\n        if isinstance(ex, NonGrpcException):\n            self.caught_custom_exception = True\n            await context.abort(\n                NonGrpcException.TEST_STATUS_CODE, NonGrpcException.TEST_DETAILS\n            )\n        else:\n            
await super().handle_exception(\n                ex, request_or_iterator, context, method_name\n            )\n\n\ndef _get_interceptors(\n    aio: bool, status_on_unknown_exception: Optional[grpc.StatusCode] = None\n) -> List[Union[ExceptionToStatusInterceptor, AsyncExceptionToStatusInterceptor]]:\n    return (\n        [\n            AsyncExceptionToStatusInterceptor(\n                status_on_unknown_exception=status_on_unknown_exception\n            )\n        ]\n        if aio\n        else [\n            ExceptionToStatusInterceptor(\n                status_on_unknown_exception=status_on_unknown_exception\n            )\n        ]\n    )\n\n\ndef test_repr():\n    \"\"\"repr() should display the class name, status code, and details.\"\"\"\n    assert (\n        repr(gx.GrpcException(details=\"oops\"))\n        == \"GrpcException(status_code=UNKNOWN, details='oops')\"\n    )\n    assert (\n        repr(gx.GrpcException(status_code=grpc.StatusCode.NOT_FOUND, details=\"oops\"))\n        == \"GrpcException(status_code=NOT_FOUND, details='oops')\"\n    )\n    assert (\n        repr(gx.NotFound(details=\"?\")) == \"NotFound(status_code=NOT_FOUND, details='?')\"\n    )\n\n\ndef test_status_string():\n    \"\"\"status_string should be the string version of the status code.\"\"\"\n    assert gx.GrpcException().status_string == \"UNKNOWN\"\n    assert (\n        gx.GrpcException(status_code=grpc.StatusCode.NOT_FOUND).status_string\n        == \"NOT_FOUND\"\n    )\n    assert gx.NotFound().status_string == \"NOT_FOUND\"\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_no_exception(aio):\n    \"\"\"An RPC with no exceptions should work as if the interceptor wasn't there.\"\"\"\n    interceptors = _get_interceptors(aio)\n    with dummy_client(\n        special_cases={}, interceptors=interceptors, aio_server=aio\n    ) as client:\n        assert client.Execute(DummyRequest(input=\"foo\")).output == \"foo\"\n\n\n@pytest.mark.parametrize(\"aio\", [False, 
True])\ndef test_custom_details(aio):\n    \"\"\"We can set custom details.\"\"\"\n    interceptors = _get_interceptors(aio)\n    special_cases = {\"error\": raises(gx.NotFound(details=\"custom\"))}\n    with dummy_client(\n        special_cases=special_cases, interceptors=interceptors, aio_server=aio\n    ) as client:\n        assert (\n            client.Execute(DummyRequest(input=\"foo\")).output == \"foo\"\n        )  # Test a happy path too\n        with pytest.raises(grpc.RpcError) as e:\n            client.Execute(DummyRequest(input=\"error\"))\n        assert e.value.code() == grpc.StatusCode.NOT_FOUND\n        assert e.value.details() == \"custom\"\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_non_grpc_exception(aio):\n    \"\"\"Exceptions other than GrpcExceptions are ignored.\"\"\"\n    interceptors = _get_interceptors(aio)\n    special_cases = {\"error\": raises(ValueError(\"oops\"))}\n    with dummy_client(\n        special_cases=special_cases, interceptors=interceptors, aio_server=aio\n    ) as client:\n        with pytest.raises(grpc.RpcError) as e:\n            client.Execute(DummyRequest(input=\"error\"))\n        assert e.value.code() == grpc.StatusCode.UNKNOWN\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_non_grpc_exception_with_override(aio):\n    \"\"\"We can set a custom status code when non-GrpcExceptions are raised.\"\"\"\n    interceptors = _get_interceptors(\n        aio, status_on_unknown_exception=grpc.StatusCode.INTERNAL\n    )\n    special_cases = {\"error\": raises(ValueError(\"oops\"))}\n    with dummy_client(\n        special_cases=special_cases, interceptors=interceptors, aio_server=aio\n    ) as client:\n        with pytest.raises(grpc.RpcError) as e:\n            client.Execute(DummyRequest(input=\"error\"))\n        assert e.value.code() == grpc.StatusCode.INTERNAL\n        assert re.fullmatch(r\"ValueError\\('oops',?\\)\", e.value.details())\n\n\n@pytest.mark.parametrize(\"aio\", [False, 
True])\ndef test_aborted_context(aio):\n    \"\"\"If the context is aborted, the exception is propagated.\"\"\"\n    def error(request: Any, context: grpc.ServicerContext) -> None:\n        context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted')\n\n    async def async_error(request: Any, context: grpc_aio.ServicerContext) -> None:\n        await context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted')\n\n    interceptors = _get_interceptors(aio, grpc.StatusCode.INTERNAL)\n    special_cases = {\n        \"error\": async_error if aio else error\n    }\n\n    with dummy_client(\n        special_cases=special_cases, interceptors=interceptors, aio_server=aio\n    ) as client:\n        with pytest.raises(grpc.RpcError) as e:\n            client.Execute(DummyRequest(input=\"error\"))\n        assert e.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED\n        assert re.fullmatch(r\"resource exhausted\", e.value.details())\n\n\ndef test_override_with_ok():\n    \"\"\"We cannot set the default status code to OK.\"\"\"\n    with pytest.raises(ValueError):\n        ExceptionToStatusInterceptor(status_on_unknown_exception=grpc.StatusCode.OK)\n    with pytest.raises(ValueError):\n        AsyncExceptionToStatusInterceptor(\n            status_on_unknown_exception=grpc.StatusCode.OK\n        )\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_all_exceptions(aio):\n    \"\"\"Every gRPC status code is represented, and they each are unique.\n\n    Make sure we aren't missing any status codes, and that we didn't copy paste the\n    same status code or details into two different classes.\n    \"\"\"\n    interceptors = _get_interceptors(aio)\n    all_status_codes = {sc for sc in grpc.StatusCode if sc != grpc.StatusCode.OK}\n    seen_codes = set()\n    seen_details = set()\n\n    for sc in all_status_codes:\n        ex = getattr(gx, _snake_to_camel(sc.name))\n        assert ex\n        special_cases = {\"error\": raises(ex())}\n        with 
dummy_client(\n            special_cases=special_cases, interceptors=interceptors, aio_server=aio\n        ) as client:\n            with pytest.raises(grpc.RpcError) as e:\n                client.Execute(DummyRequest(input=\"error\"))\n            assert e.value.code() == sc\n            assert e.value.details() == ex.details\n            seen_codes.add(sc)\n            seen_details.add(ex.details)\n\n    assert seen_codes == all_status_codes\n    assert len(seen_details) == len(all_status_codes)\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_exception_in_streaming_response(aio):\n    \"\"\"Exceptions are raised correctly from streaming responses.\"\"\"\n    interceptors = _get_interceptors(aio)\n    with dummy_client(\n        special_cases={\"error\": raises(gx.NotFound(\"not found!\"))},\n        interceptors=interceptors,\n        aio_server=aio,\n    ) as client:\n        with pytest.raises(grpc.RpcError) as e:\n            list(client.ExecuteServerStream(DummyRequest(input=\"error\")))\n        assert e.value.code() == grpc.StatusCode.NOT_FOUND\n        assert e.value.details() == \"not found!\"\n\n\ndef _snake_to_camel(s: str) -> str:\n    return \"\".join([p.title() for p in s.split(\"_\")])\n\n\ndef test_not_ok():\n    \"\"\"We cannot create a GrpcException with an OK status code.\"\"\"\n    with pytest.raises(ValueError):\n        gx.GrpcException(status_code=grpc.StatusCode.OK)\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_extending(aio):\n    \"\"\"We can extend ExceptionToStatusInterceptor.\"\"\"\n    interceptor = (\n        AsyncExtendedExceptionToStatusInterceptor()\n        if aio\n        else ExtendedExceptionToStatusInterceptor()\n    )\n    special_cases = {\"error\": raises(NonGrpcException())}\n    with dummy_client(\n        special_cases=special_cases, interceptors=[interceptor], aio_server=aio\n    ) as client:\n        assert (\n            client.Execute(DummyRequest(input=\"foo\")).output == \"foo\"\n 
       )  # Test a happy path too\n        assert not interceptor.caught_custom_exception\n        with pytest.raises(grpc.RpcError) as e:\n            client.Execute(DummyRequest(input=\"error\"))\n        assert e.value.code() == NonGrpcException.TEST_STATUS_CODE\n        assert e.value.details() == NonGrpcException.TEST_DETAILS\n        assert interceptor.caught_custom_exception\n"
  },
  {
    "path": "tests/test_server.py",
    "content": "\"\"\"Test cases for the grpc-interceptor base ServerInterceptor.\"\"\"\n\nfrom collections import defaultdict\n\nimport grpc\nimport pytest\n\nfrom grpc_interceptor import (\n    AsyncServerInterceptor,\n    MethodName,\n    parse_method_name,\n    ServerInterceptor,\n)\nfrom grpc_interceptor.testing import dummy_client, DummyRequest\nfrom grpc_interceptor.testing.dummy_client import dummy_channel\n\n\nclass CountingInterceptor(ServerInterceptor):\n    \"\"\"A test interceptor that counts calls and exceptions.\"\"\"\n\n    def __init__(self):\n        self.num_calls = defaultdict(int)\n        self.num_errors = defaultdict(int)\n\n    def intercept(self, method, request, context, method_name):\n        \"\"\"Count each call and exception.\"\"\"\n        self.num_calls[method_name] += 1\n        try:\n            return method(request, context)\n        except Exception:\n            self.num_errors[method_name] += 1\n            raise\n\n\nclass AsyncCountingInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that counts calls and exceptions.\"\"\"\n\n    def __init__(self):\n        self.num_calls = defaultdict(int)\n        self.num_errors = defaultdict(int)\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Count each call and exception.\"\"\"\n        self.num_calls[method_name] += 1\n        try:\n            return await method(request, context)\n        except Exception:\n            self.num_errors[method_name] += 1\n            raise\n\n\nclass SideEffectInterceptor(ServerInterceptor):\n    \"\"\"A test interceptor that calls a function for the side effect.\"\"\"\n\n    def __init__(self, side_effect):\n        self._side_effect = side_effect\n\n    def intercept(self, method, request, context, method_name):\n        \"\"\"Call the side effect and then the RPC method.\"\"\"\n        self._side_effect()\n        return method(request, context)\n\n\nclass 
AsyncSideEffectInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that calls a function for the side effect.\"\"\"\n\n    def __init__(self, side_effect):\n        self._side_effect = side_effect\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Call the side effect and then the RPC method.\"\"\"\n        self._side_effect()\n        return await method(request, context)\n\n\nclass UppercasingInterceptor(ServerInterceptor):\n    \"\"\"A test interceptor that modifies the request by uppercasing the input field.\"\"\"\n\n    def intercept(self, method, request, context, method_name):\n        \"\"\"Uppercases request.input.\"\"\"\n        request.input = request.input.upper()\n        return method(request, context)\n\n\nclass AsyncUppercasingInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that modifies the request by uppercasing the input field.\"\"\"\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Uppercases request.input.\"\"\"\n        request.input = request.input.upper()\n        return await method(request, context)\n\n\nclass AbortingInterceptor(ServerInterceptor):\n    \"\"\"A test interceptor that aborts before calling the handler.\"\"\"\n\n    def __init__(self, message):\n        self._message = message\n\n    def intercept(self, method, request, context, method_name):\n        \"\"\"Calls abort.\"\"\"\n        context.abort(grpc.StatusCode.ABORTED, self._message)\n\n\nclass AsyncAbortingInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that aborts before calling the handler.\"\"\"\n\n    def __init__(self, message):\n        self._message = message\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Calls abort.\"\"\"\n        await context.abort(grpc.StatusCode.ABORTED, self._message)\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_call_counts(aio):\n    \"\"\"The counts 
should be correct.\"\"\"\n    intr_type = AsyncCountingInterceptor if aio else CountingInterceptor\n    intr = intr_type()\n    interceptors = [intr]\n\n    special_cases = {\"error\": lambda r, c: 1 / 0}\n    with dummy_client(\n        special_cases=special_cases, interceptors=interceptors, aio_server=aio\n    ) as client:\n        assert client.Execute(DummyRequest(input=\"foo\")).output == \"foo\"\n        assert len(intr.num_calls) == 1\n        assert intr.num_calls[\"/DummyService/Execute\"] == 1\n        assert len(intr.num_errors) == 0\n\n        with pytest.raises(grpc.RpcError):\n            client.Execute(DummyRequest(input=\"error\"))\n\n        assert len(intr.num_calls) == 1\n        assert intr.num_calls[\"/DummyService/Execute\"] == 2\n        assert len(intr.num_errors) == 1\n        assert intr.num_errors[\"/DummyService/Execute\"] == 1\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_interceptor_chain(aio):\n    \"\"\"Interceptors are called in the right order.\"\"\"\n    trace = []\n    intr_type = AsyncSideEffectInterceptor if aio else SideEffectInterceptor\n    interceptor1 = intr_type(lambda: trace.append(1))\n    interceptor2 = intr_type(lambda: trace.append(2))\n    with dummy_client(\n        special_cases={}, interceptors=[interceptor1, interceptor2], aio_server=aio\n    ) as client:\n        assert client.Execute(DummyRequest(input=\"test\")).output == \"test\"\n        assert trace == [1, 2]\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_modifying_interceptor(aio):\n    \"\"\"Interceptors can modify requests.\"\"\"\n    intr_type = AsyncUppercasingInterceptor if aio else UppercasingInterceptor\n    interceptor = intr_type()\n    with dummy_client(\n        special_cases={}, interceptors=[interceptor], aio_server=aio\n    ) as client:\n        assert client.Execute(DummyRequest(input=\"test\")).output == \"TEST\"\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_aborting_interceptor(aio):\n  
  \"\"\"context.abort called in an interceptor works.\"\"\"\n    intr_type = AsyncAbortingInterceptor if aio else AbortingInterceptor\n    interceptor = intr_type(\"oh no\")\n    with dummy_client(\n        special_cases={}, interceptors=[interceptor], aio_server=aio\n    ) as client:\n        with pytest.raises(grpc.RpcError) as e:\n            client.Execute(DummyRequest(input=\"test\"))\n        assert e.value.code() == grpc.StatusCode.ABORTED\n        assert e.value.details() == \"oh no\"\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_method_not_found(aio):\n    \"\"\"Calling undefined endpoints should return Unimplemented.\n\n    Interceptors are not invoked when the RPC call is not handled.\n    \"\"\"\n    intr_type = AsyncCountingInterceptor if aio else CountingInterceptor\n    intr = intr_type()\n    interceptors = [intr]\n\n    with dummy_channel(\n        special_cases={}, interceptors=interceptors, aio_server=aio\n    ) as channel:\n        with pytest.raises(grpc.RpcError) as e:\n            channel.unary_unary(\n                \"/DummyService/Unimplemented\",\n            )(b\"\")\n        assert e.value.code() == grpc.StatusCode.UNIMPLEMENTED\n        assert len(intr.num_calls) == 0\n        assert len(intr.num_errors) == 0\n\n\ndef test_method_name():\n    \"\"\"Fields are correct and fully_qualified_service work.\"\"\"\n    mn = MethodName(\"foo.bar\", \"SearchService\", \"Search\")\n    assert mn.package == \"foo.bar\"\n    assert mn.service == \"SearchService\"\n    assert mn.method == \"Search\"\n    assert mn.fully_qualified_service == \"foo.bar.SearchService\"\n\n\ndef test_empty_package_method_name():\n    \"\"\"fully_qualified_service works when there's no package.\"\"\"\n    mn = MethodName(\"\", \"SearchService\", \"Search\")\n    assert mn.fully_qualified_service == \"SearchService\"\n\n\ndef test_parse_method_name():\n    \"\"\"parse_method_name parses fields when there's a package.\"\"\"\n    mn = 
parse_method_name(\"/foo.bar.SearchService/Search\")\n    assert mn.package == \"foo.bar\"\n    assert mn.service == \"SearchService\"\n    assert mn.method == \"Search\"\n\n\ndef test_parse_empty_package():\n    \"\"\"parse_method_name works with no package.\"\"\"\n    mn = parse_method_name(\"/SearchService/Search\")\n    assert mn.package == \"\"\n    assert mn.service == \"SearchService\"\n    assert mn.method == \"Search\"\n"
  },
  {
    "path": "tests/test_streaming.py",
    "content": "\"\"\"Test cases for streaming RPCs.\"\"\"\n\nimport sys\n\nimport grpc\nimport pytest\n\nfrom grpc_interceptor import AsyncServerInterceptor, ServerInterceptor\nfrom grpc_interceptor.testing import dummy_client, DummyRequest\n\n\nclass StreamingInterceptor(ServerInterceptor):\n    \"\"\"A test interceptor that streams.\"\"\"\n\n    def intercept(self, method, request, context, method_name):\n        \"\"\"Doesn't do anything; just make sure we handle streaming RPCs.\"\"\"\n        return method(request, context)\n\n\nclass AsyncStreamingInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that streams.\"\"\"\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Doesn't do anything; just make sure we handle streaming RPCs.\"\"\"\n        response_or_iterator = method(request, context)\n        if hasattr(response_or_iterator, \"__aiter__\"):\n            return response_or_iterator\n        else:\n            return await response_or_iterator\n\n\nclass ServerStreamingLoggingInterceptor(ServerInterceptor):\n    \"\"\"A test interceptor that logs a stream of server responses.\"\"\"\n\n    def __init__(self):\n        self._logs = []\n\n    def intercept(self, method, request, context, method_name):\n        \"\"\"Log each response object and re-yield.\"\"\"\n        for resp in method(request, context):\n            self._logs.append(resp.output)\n            yield resp\n\n\nclass AsyncServerStreamingLoggingInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that logs a stream of server responses.\"\"\"\n\n    def __init__(self):\n        self._logs = []\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Log each response object and re-yield.\"\"\"\n        async for resp in method(request, context):\n            self._logs.append(resp.output)\n            yield resp\n\n\nclass ServerOmniLoggingInterceptor(ServerInterceptor):\n    \"\"\"A test 
interceptor that logs both unary and streaming server responses.\"\"\"\n\n    def __init__(self):\n        self._logs = []\n\n    def _log_and_yield(self, iterator):\n        logs = []\n        for resp in iterator:\n            logs.append(resp.output)\n            yield resp\n        self._logs.append(logs)\n\n    def intercept(self, method, request, context, method_name):\n        \"\"\"Log each response object and re-yield.\"\"\"\n        response_or_iterator = method(request, context)\n        if hasattr(response_or_iterator, \"__iter__\"):\n            return self._log_and_yield(response_or_iterator)\n        else:\n            self._logs.append(response_or_iterator.output)\n            return response_or_iterator\n\n\nclass AsyncServerOmniLoggingInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that logs both unary and streaming server responses.\"\"\"\n\n    def __init__(self):\n        self._logs = []\n\n    async def _log_and_yield(self, iterator):\n        logs = []\n        async for resp in iterator:\n            logs.append(resp.output)\n            yield resp\n        self._logs.append(logs)\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Log each response object and re-yield.\"\"\"\n        response_or_iterator = method(request, context)\n        if hasattr(response_or_iterator, \"__aiter__\"):\n            return self._log_and_yield(response_or_iterator)\n        else:\n            response_or_iterator = await response_or_iterator\n            self._logs.append(response_or_iterator.output)\n            return response_or_iterator\n\n\nclass ClientStreamingLoggingInterceptor(ServerInterceptor):\n    \"\"\"A test interceptor that logs a stream of server requests.\"\"\"\n\n    def __init__(self):\n        self._logs = []\n\n    def _log_and_yield(self, request):\n        for r in request:\n            self._logs.append(r.input)\n            yield r\n\n    def intercept(self, method, request, 
context, method_name):\n        \"\"\"Log each request object and pass through.\"\"\"\n        req = self._log_and_yield(request)\n        return method(req, context)\n\n\nclass AsyncClientStreamingLoggingInterceptor(AsyncServerInterceptor):\n    \"\"\"A test interceptor that logs a stream of server requests.\"\"\"\n\n    def __init__(self):\n        self._logs = []\n\n    async def _log_and_yield(self, request):\n        async for r in request:\n            self._logs.append(r.input)\n            yield r\n\n    async def intercept(self, method, request, context, method_name):\n        \"\"\"Log each request object and pass through.\"\"\"\n        req = self._log_and_yield(request)\n        return await method(req, context)\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\n@pytest.mark.parametrize(\"aio_rw\", [False, True])\ndef test_client_streaming(aio, aio_rw):\n    \"\"\"Client streaming should work.\"\"\"\n    intr = AsyncStreamingInterceptor() if aio else StreamingInterceptor()\n    interceptors = [intr]\n    special_cases = {\"error\": lambda r, c: 1 / 0}\n    with dummy_client(\n        special_cases=special_cases,\n        interceptors=interceptors,\n        aio_server=aio,\n        aio_read_write=aio_rw,\n    ) as client:\n        inputs = [\"foo\", \"bar\"]\n        input_iter = (DummyRequest(input=input) for input in inputs)\n        assert client.ExecuteClientStream(input_iter).output == \"foobar\"\n\n        inputs = [\"foo\", \"error\"]\n        input_iter = (DummyRequest(input=input) for input in inputs)\n        with pytest.raises(grpc.RpcError):\n            client.ExecuteClientStream(input_iter)\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_server_streaming(aio):\n    \"\"\"Server streaming should work.\"\"\"\n    intr = AsyncStreamingInterceptor() if aio else StreamingInterceptor()\n    interceptors = [intr]\n    with dummy_client(\n        special_cases={}, interceptors=interceptors, aio_server=aio\n    ) as client:\n     
   output = [\n            r.output for r in client.ExecuteServerStream(DummyRequest(input=\"foo\"))\n        ]\n        assert output == [\"f\", \"o\", \"o\"]\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_client_server_streaming(aio):\n    \"\"\"Bidirectional streaming should work.\"\"\"\n    intr = AsyncStreamingInterceptor() if aio else StreamingInterceptor()\n    interceptors = [intr]\n    with dummy_client(\n        special_cases={}, interceptors=interceptors, aio_server=aio\n    ) as client:\n        inputs = [\"foo\", \"bar\"]\n        input_iter = (DummyRequest(input=input) for input in inputs)\n        response = client.ExecuteClientServerStream(input_iter)\n        assert [r.output for r in response] == inputs\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_interceptor_iterates_server_streaming(aio):\n    \"\"\"The iterator should be able to iterate over streamed server responses.\"\"\"\n    intr = (\n        AsyncServerStreamingLoggingInterceptor()\n        if aio\n        else ServerStreamingLoggingInterceptor()\n    )\n    interceptors = [intr]\n    with dummy_client(\n        special_cases={}, interceptors=interceptors, aio_server=aio\n    ) as client:\n        output = [\n            r.output for r in client.ExecuteServerStream(DummyRequest(input=\"foo\"))\n        ]\n        assert output == [\"f\", \"o\", \"o\"]\n        assert intr._logs == [\"f\", \"o\", \"o\"]\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_interceptor_handles_both_unary_and_streaming(aio):\n    \"\"\"The iterator should be able to iterate over streamed server responses.\"\"\"\n    intr = (\n        AsyncServerOmniLoggingInterceptor() if aio else ServerOmniLoggingInterceptor()\n    )\n    interceptors = [intr]\n    with dummy_client(\n        special_cases={}, interceptors=interceptors, aio_server=aio\n    ) as client:\n        output = [\n            r.output for r in client.ExecuteServerStream(DummyRequest(input=\"foo\"))\n     
   ]\n        assert output == [\"f\", \"o\", \"o\"]\n        assert intr._logs == [[\"f\", \"o\", \"o\"]]\n\n        r = client.Execute(DummyRequest(input=\"bar\"))\n        assert r.output == \"bar\"\n        assert intr._logs == [[\"f\", \"o\", \"o\"], \"bar\"]\n\n\n@pytest.mark.parametrize(\"aio\", [False, True])\ndef test_client_log_streaming(aio):\n    \"\"\"Client streaming should work when re-yielding.\"\"\"\n    intr = (\n        AsyncClientStreamingLoggingInterceptor()\n        if aio\n        else ClientStreamingLoggingInterceptor()\n    )\n    interceptors = [intr]\n    with dummy_client(\n        special_cases={}, interceptors=interceptors, aio_server=aio\n    ) as client:\n        inputs = [\"foo\", \"bar\"]\n        input_iter = (DummyRequest(input=input) for input in inputs)\n        assert client.ExecuteClientStream(input_iter).output == \"foobar\"\n        assert intr._logs == inputs\n\n\n@pytest.mark.skipif(sys.version_info < (3, 7), reason=\"requires Python 3.7\")\n@pytest.mark.asyncio\nasync def test_client_streaming_write_method():\n    \"\"\"Client streaming should work when using write().\"\"\"\n    intr = AsyncClientStreamingLoggingInterceptor()\n    interceptors = [intr]\n    with dummy_client(\n        special_cases={}, interceptors=interceptors, aio_server=True, aio_client=True\n    ) as client:\n        call = client.ExecuteClientStream()\n        await call.write(DummyRequest(input=\"foo\"))\n        await call.write(DummyRequest(input=\"bar\"))\n        await call.done_writing()\n        response = await call\n        assert response.output == \"foobar\"\n        assert intr._logs == [\"foo\", \"bar\"]\n"
  }
]