Repository: Kludex/starlette Branch: main Commit: 9ee951980bae Files: 120 Total size: 817.8 KB Directory structure: gitextract_gn9w6fxg/ ├── .github/ │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE/ │ │ ├── 1-issue.md │ │ └── config.yml │ ├── dependabot.yml │ ├── pull_request_template.md │ └── workflows/ │ ├── main.yml │ └── publish.yml ├── .gitignore ├── CITATION.cff ├── LICENSE.md ├── README.md ├── docs/ │ ├── CNAME │ ├── applications.md │ ├── authentication.md │ ├── background.md │ ├── config.md │ ├── contributing.md │ ├── css/ │ │ └── custom.css │ ├── database.md │ ├── endpoints.md │ ├── exceptions.md │ ├── graphql.md │ ├── index.md │ ├── js/ │ │ └── custom.js │ ├── lifespan.md │ ├── middleware.md │ ├── overrides/ │ │ ├── main.html │ │ └── partials/ │ │ └── toc-item.html │ ├── release-notes.md │ ├── requests.md │ ├── responses.md │ ├── routing.md │ ├── schemas.md │ ├── server-push.md │ ├── staticfiles.md │ ├── templates.md │ ├── testclient.md │ ├── third-party-packages.md │ ├── threadpool.md │ └── websockets.md ├── mkdocs.yml ├── pyproject.toml ├── scripts/ │ ├── README.md │ ├── build │ ├── check │ ├── coverage │ ├── docs │ ├── install │ ├── lint │ ├── sync-version │ └── test ├── starlette/ │ ├── __init__.py │ ├── _exception_handler.py │ ├── _utils.py │ ├── applications.py │ ├── authentication.py │ ├── background.py │ ├── concurrency.py │ ├── config.py │ ├── convertors.py │ ├── datastructures.py │ ├── endpoints.py │ ├── exceptions.py │ ├── formparsers.py │ ├── middleware/ │ │ ├── __init__.py │ │ ├── authentication.py │ │ ├── base.py │ │ ├── cors.py │ │ ├── errors.py │ │ ├── exceptions.py │ │ ├── gzip.py │ │ ├── httpsredirect.py │ │ ├── sessions.py │ │ ├── trustedhost.py │ │ └── wsgi.py │ ├── py.typed │ ├── requests.py │ ├── responses.py │ ├── routing.py │ ├── schemas.py │ ├── staticfiles.py │ ├── status.py │ ├── templating.py │ ├── testclient.py │ ├── types.py │ └── websockets.py └── tests/ ├── __init__.py ├── conftest.py ├── middleware/ │ ├── __init__.py │ ├── 
test_base.py │ ├── test_cors.py │ ├── test_errors.py │ ├── test_gzip.py │ ├── test_https_redirect.py │ ├── test_middleware.py │ ├── test_session.py │ ├── test_trusted_host.py │ └── test_wsgi.py ├── statics/ │ └── example.txt ├── test__utils.py ├── test_applications.py ├── test_authentication.py ├── test_background.py ├── test_concurrency.py ├── test_config.py ├── test_convertors.py ├── test_datastructures.py ├── test_endpoints.py ├── test_exceptions.py ├── test_formparsers.py ├── test_requests.py ├── test_responses.py ├── test_routing.py ├── test_schemas.py ├── test_staticfiles.py ├── test_status.py ├── test_templates.py ├── test_testclient.py ├── test_websockets.py └── types.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/FUNDING.yml ================================================ github: Kludex ================================================ FILE: .github/ISSUE_TEMPLATE/1-issue.md ================================================ --- name: Issue about: Please only raise an issue if you've been advised to do so after discussion. Thanks! 🙏 --- The starting point for issues should usually be a discussion... https://github.com/Kludex/starlette/discussions Possible bugs may be raised as a "Potential Issue" discussion, feature requests may be raised as an "Ideas" discussion. We can then determine if the discussion needs to be escalated into an "Issue" or not. This will help us ensure that the "Issues" list properly reflects ongoing or needed work on the project. --- - [ ] Initially raised as discussion #... 
================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ # Ref: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository#configuring-the-template-chooser blank_issues_enabled: false contact_links: - name: Discussions url: https://github.com/Kludex/starlette/discussions about: > The "Discussions" forum is where you want to start. 💖 - name: Chat url: https://discord.gg/SWU73HffbV about: > Our community chat forum. ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "uv" directory: "/" schedule: interval: "monthly" groups: python-packages: patterns: - "*" - package-ecosystem: "github-actions" directory: "/" schedule: interval: monthly groups: github-actions: patterns: - "*" ================================================ FILE: .github/pull_request_template.md ================================================ # Summary # Checklist - [ ] I understand that this PR may be closed in case there was no previous discussion. (This doesn't apply to typos!) - [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. - [ ] I've updated the documentation accordingly. 
================================================ FILE: .github/workflows/main.yml ================================================ --- name: Test Suite on: push: branches: ["main"] pull_request: branches: ["main"] jobs: tests: name: "Python ${{ matrix.python-version }}" runs-on: ubuntu-latest strategy: matrix: python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install uv uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: python-version: ${{ matrix.python-version }} enable-cache: true - name: Install dependencies run: scripts/install - name: Run linting checks run: scripts/check if: ${{ matrix.python-version != '3.14' }} - name: "Build package & docs" run: scripts/build - name: "Run tests" run: scripts/test - name: "Enforce coverage" run: scripts/coverage # https://github.com/marketplace/actions/alls-green#why check: if: always() needs: [tests] runs-on: ubuntu-latest steps: - name: Decide whether the needed jobs succeeded or failed uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # v1.2.2 with: jobs: ${{ toJSON(needs) }} ================================================ FILE: .github/workflows/publish.yml ================================================ name: Publish on: push: tags: - "*" workflow_dispatch: jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install uv uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1 with: python-version: "3.11" enable-cache: true - name: Install dependencies run: scripts/install - name: Build package & docs run: scripts/build - name: Upload package distributions uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: package-distributions path: dist/ - name: Upload documentation uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 
with: name: documentation path: site/ pypi-publish: runs-on: ubuntu-latest needs: build if: success() && startsWith(github.ref, 'refs/tags/') permissions: id-token: write environment: name: pypi url: https://pypi.org/project/starlette steps: - name: Download artifacts uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0 with: name: package-distributions path: dist/ - name: Publish distribution 📦 to PyPI uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 docs-publish: runs-on: ubuntu-latest needs: build permissions: contents: read pages: write id-token: write environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} steps: - name: Configure GitHub Pages uses: actions/configure-pages@983d7736d9b0ae728b81ab479565c72886d7745b # v5.0.0 - name: Download artifacts uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0 with: name: documentation path: site/ - name: Upload Pages artifact uses: actions/upload-pages-artifact@7b1f4a764d45c48632c6b24a0339c27f5614fb0b # v4.0.0 with: path: site - name: Deploy to GitHub Pages uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4.0.5 id: deployment docs-cloudflare: runs-on: ubuntu-latest needs: build environment: name: cloudflare url: https://starlette.dev steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Download artifacts uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0 with: name: documentation path: site/ - uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3.14.1 with: apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }} command: > pages deploy ./site --project-name starlette --commit-hash ${{ github.sha }} --branch main ================================================ FILE: .gitignore ================================================ *.pyc test.db .coverage .pytest_cache/ .mypy_cache/ __pycache__/ htmlcov/ 
site/ *.egg-info/ venv*/ .venv/ .python-version build/ dist/ ================================================ FILE: CITATION.cff ================================================ # This CITATION.cff file was generated with cffinit. # Visit https://bit.ly/cffinit to generate yours today! cff-version: 1.2.0 title: Starlette message: >- If you use this software, please cite it using the metadata from this file. type: software authors: - given-names: Marcelo family-names: Trylesinski email: marcelotryle@gmail.com - given-names: Tom family-names: Christie email: tom@tomchristie.com repository-code: "https://github.com/Kludex/starlette" url: "https://starlette.dev/" abstract: Starlette is an ASGI web framework for Python. keywords: - asgi - starlette license: BSD-3-Clause ================================================ FILE: LICENSE.md ================================================ Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================

starlette-logo

✨ The little ASGI framework that shines. ✨

--- [![Build Status](https://github.com/Kludex/starlette/workflows/Test%20Suite/badge.svg)](https://github.com/Kludex/starlette/actions) [![Package version](https://badge.fury.io/py/starlette.svg)](https://pypi.python.org/pypi/starlette) [![Supported Python Version](https://img.shields.io/pypi/pyversions/starlette.svg?color=%2334D058)](https://pypi.org/project/starlette) [![Discord](https://img.shields.io/discord/1051468649518616576?logo=discord&logoColor=ffffff&color=7389D8&labelColor=6A7EC2)](https://discord.gg/RxKUF5JuHs) --- **Documentation**: https://starlette.dev **Source Code**: https://github.com/Kludex/starlette --- # Starlette Starlette is a lightweight [ASGI][asgi] framework/toolkit, which is ideal for building async web services in Python. It is production-ready, and gives you the following: * A lightweight, low-complexity HTTP web framework. * WebSocket support. * In-process background tasks. * Startup and shutdown events. * Test client built on `httpx`. * CORS, GZip, Static Files, Streaming responses. * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. * Few hard dependencies. * Compatible with `asyncio` and `trio` backends. * Great overall performance [against independent benchmarks][techempower]. ## Installation ```shell $ pip install starlette ``` You'll also want to install an ASGI server, such as [uvicorn](https://www.uvicorn.org/), [daphne](https://github.com/django/daphne/), or [hypercorn](https://hypercorn.readthedocs.io/en/latest/). 
```shell $ pip install uvicorn ``` ## Example ```python title="main.py" from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.routing import Route async def homepage(request): return JSONResponse({'hello': 'world'}) routes = [ Route("/", endpoint=homepage) ] app = Starlette(debug=True, routes=routes) ``` Then run the application using Uvicorn: ```shell $ uvicorn main:app ``` ## Dependencies Starlette only requires `anyio`, and the following are optional: * [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. * [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support. You can install all of these with `pip install starlette[full]`. ## Framework or Toolkit Starlette is designed to be used either as a complete framework, or as an ASGI toolkit. You can use any of its components independently. ```python from starlette.responses import PlainTextResponse async def app(scope, receive, send): assert scope['type'] == 'http' response = PlainTextResponse('Hello, world!') await response(scope, receive, send) ``` Run the `app` application in `example.py`: ```shell $ uvicorn example:app INFO: Started server process [11509] INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) ``` Run uvicorn with `--reload` to enable auto-reloading on code changes. ## Modularity The modularity that Starlette is designed on promotes building re-usable components that can be shared between any ASGI framework. This should enable an ecosystem of shared middleware and mountable applications. The clean API separation also means it's easier to understand each component in isolation. ---

Starlette is BSD licensed code.
Designed & crafted with care.

— ⭐️ —

[asgi]: https://asgi.readthedocs.io/en/latest/ [httpx]: https://www.python-httpx.org/ [jinja2]: https://jinja.palletsprojects.com/ [python-multipart]: https://multipart.fastapiexpert.com/ [itsdangerous]: https://itsdangerous.palletsprojects.com/ [sqlalchemy]: https://www.sqlalchemy.org [pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation [techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf ================================================ FILE: docs/CNAME ================================================ www.starlette.io ================================================ FILE: docs/applications.md ================================================ ??? abstract "API Reference" ::: starlette.applications.Starlette options: parameter_headings: false show_root_heading: true heading_level: 3 filters: - "__init__" Starlette includes an application class `Starlette` that nicely ties together all of its other functionality. ```python from contextlib import asynccontextmanager from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.routing import Route, Mount, WebSocketRoute from starlette.staticfiles import StaticFiles def homepage(request): return PlainTextResponse('Hello, world!') def user_me(request): username = "John Doe" return PlainTextResponse('Hello, %s!' % username) def user(request): username = request.path_params['username'] return PlainTextResponse('Hello, %s!' 
% username) async def websocket_endpoint(websocket): await websocket.accept() await websocket.send_text('Hello, websocket!') await websocket.close() @asynccontextmanager async def lifespan(app): print('Startup') yield print('Shutdown') routes = [ Route('/', homepage), Route('/user/me', user_me), Route('/user/{username}', user), WebSocketRoute('/ws', websocket_endpoint), Mount('/static', StaticFiles(directory="static")), ] app = Starlette(debug=True, routes=routes, lifespan=lifespan) ``` ### Storing state on the app instance You can store arbitrary extra state on the application instance, using the generic `app.state` attribute. For example: ```python app.state.ADMIN_EMAIL = 'admin@example.org' ``` ### Accessing the app instance Where a `request` is available (i.e. endpoints and middleware), the app is available on `request.app`. ================================================ FILE: docs/authentication.md ================================================ Starlette offers a simple but powerful interface for handling authentication and permissions. Once you've installed `AuthenticationMiddleware` with an appropriate authentication backend the `request.user` and `request.auth` interfaces will be available in your endpoints. 
```python from starlette.applications import Starlette from starlette.authentication import ( AuthCredentials, AuthenticationBackend, AuthenticationError, SimpleUser ) from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route import base64 import binascii class BasicAuthBackend(AuthenticationBackend): async def authenticate(self, conn): if "Authorization" not in conn.headers: return auth = conn.headers["Authorization"] try: scheme, credentials = auth.split() if scheme.lower() != 'basic': return decoded = base64.b64decode(credentials).decode("ascii") except (ValueError, UnicodeDecodeError, binascii.Error) as exc: raise AuthenticationError('Invalid basic auth credentials') from exc username, _, password = decoded.partition(":") # TODO: You'd want to verify the username and password here. return AuthCredentials(["authenticated"]), SimpleUser(username) async def homepage(request): if request.user.is_authenticated: return PlainTextResponse('Hello, ' + request.user.display_name) return PlainTextResponse('Hello, you') routes = [ Route("/", endpoint=homepage) ] middleware = [ Middleware(AuthenticationMiddleware, backend=BasicAuthBackend()) ] app = Starlette(routes=routes, middleware=middleware) ``` ## Users Once `AuthenticationMiddleware` is installed, the `request.user` interface will be available to endpoints or other middleware. This interface should subclass `BaseUser`, which provides two properties, as well as whatever other information your user model includes. * `.is_authenticated` * `.display_name` Starlette provides two built-in user implementations: `UnauthenticatedUser()`, and `SimpleUser(username)`. ## AuthCredentials It is important that authentication credentials are treated as a separate concept from users. An authentication scheme should be able to restrict or grant particular privileges independently of the user identity.
The `AuthCredentials` class provides the basic interface that `request.auth` exposes: * `.scopes` ## Permissions Permissions are implemented as an endpoint decorator that enforces that the incoming request includes the required authentication scopes. ```python from starlette.authentication import requires @requires('authenticated') async def dashboard(request): ... ``` You can include either one or multiple required scopes: ```python from starlette.authentication import requires @requires(['authenticated', 'admin']) async def dashboard(request): ... ``` By default, 403 responses will be returned when permissions are not granted. In some cases you might want to customize this, for example to hide information about the URL layout from unauthenticated users. ```python from starlette.authentication import requires @requires(['authenticated', 'admin'], status_code=404) async def dashboard(request): ... ``` !!! note The `status_code` parameter is not supported with WebSockets. The 403 (Forbidden) status code will always be used for those. Alternatively you might want to redirect unauthenticated users to a different page. ```python from starlette.authentication import requires async def homepage(request): ... @requires('authenticated', redirect='homepage') async def dashboard(request): ... ``` When redirecting users, the page you redirect them to will include the URL they originally requested in the `next` query param: ```python from starlette.authentication import requires from starlette.responses import RedirectResponse @requires('authenticated', redirect='login') async def admin(request): ...
async def login(request): if request.method == "POST": # Now that the user is authenticated, # we can send them to their original request destination if request.user.is_authenticated: next_url = request.query_params.get("next") if next_url: return RedirectResponse(next_url) return RedirectResponse("/") ``` For class-based endpoints, you should wrap the decorator around a method on the class. ```python from starlette.authentication import requires from starlette.endpoints import HTTPEndpoint class Dashboard(HTTPEndpoint): @requires("authenticated") async def get(self, request): ... ``` ## Custom authentication error responses You can customize the error response sent when an `AuthenticationError` is raised by an auth backend: ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import JSONResponse def on_auth_error(request: Request, exc: Exception): return JSONResponse({"error": str(exc)}, status_code=401) app = Starlette( middleware=[ Middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error), ], ) ``` ================================================ FILE: docs/background.md ================================================ Starlette includes a `BackgroundTask` class for in-process background tasks. A background task should be attached to a response, and will run only once the response has been sent. ### Background Task Used to add a single background task to a response. Signature: `BackgroundTask(func, *args, **kwargs)` ```python from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.routing import Route from starlette.background import BackgroundTask ...
async def signup(request): data = await request.json() username = data['username'] email = data['email'] task = BackgroundTask(send_welcome_email, to_address=email) message = {'status': 'Signup successful'} return JSONResponse(message, background=task) async def send_welcome_email(to_address): ... routes = [ ... Route('/user/signup', endpoint=signup, methods=['POST']) ] app = Starlette(routes=routes) ``` ### BackgroundTasks Used to add multiple background tasks to a response. Signature: `BackgroundTasks(tasks=[])` ```python from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.routing import Route from starlette.background import BackgroundTasks async def signup(request): data = await request.json() username = data['username'] email = data['email'] tasks = BackgroundTasks() tasks.add_task(send_welcome_email, to_address=email) tasks.add_task(send_admin_notification, username=username) message = {'status': 'Signup successful'} return JSONResponse(message, background=tasks) async def send_welcome_email(to_address): ... async def send_admin_notification(username): ... routes = [ Route('/user/signup', endpoint=signup, methods=['POST']) ] app = Starlette(routes=routes) ``` !!! important The tasks are executed in order. In case one of the tasks raises an exception, the following tasks will not get the opportunity to be executed. ================================================ FILE: docs/config.md ================================================ Starlette encourages a strict separation of configuration from code, following [the twelve-factor pattern][twelve-factor]. Configuration should be stored in environment variables, or in a `.env` file that is not committed to source control. ```python title="main.py" from sqlalchemy import create_engine from starlette.applications import Starlette from starlette.config import Config from starlette.datastructures import CommaSeparatedStrings, Secret # Config will be read from environment variables and/or ".env" files.
config = Config(".env") DEBUG = config('DEBUG', cast=bool, default=False) DATABASE_URL = config('DATABASE_URL') SECRET_KEY = config('SECRET_KEY', cast=Secret) ALLOWED_HOSTS = config('ALLOWED_HOSTS', cast=CommaSeparatedStrings) app = Starlette(debug=DEBUG) engine = create_engine(DATABASE_URL) ... ``` ```shell title=".env" # Don't commit this to source control. # E.g. include ".env" in your `.gitignore` file. DEBUG=True DATABASE_URL=postgresql://user:password@localhost:5432/database SECRET_KEY=43n080musdfjt54t-09sdgr ALLOWED_HOSTS=127.0.0.1, localhost ``` ## Configuration precedence The order in which configuration values are read is: * From an environment variable. * From the `.env` file. * The default value given in `config`. If none of those match, then `config(...)` will raise an error. ## Secrets For sensitive keys, the `Secret` class is useful, since it helps minimize occasions where the value it holds could leak out into tracebacks or other code introspection. To get the value of a `Secret` instance, you must explicitly cast it to a string. You should only do this at the point at which the value is used. ```python >>> from myproject import settings >>> settings.SECRET_KEY Secret('**********') >>> str(settings.SECRET_KEY) '98n349$%8b8-7yjn0n8y93T$23r' ``` !!! tip You can use `DatabaseURL` from the `databases` package [here](https://github.com/encode/databases/blob/ab5eb718a78a27afe18775754e9c0fa2ad9cd211/databases/core.py#L420) to store database URLs and avoid leaking them in the logs. ## CommaSeparatedStrings For holding multiple values inside a single config key, the `CommaSeparatedStrings` type is useful.
```python >>> from myproject import settings >>> print(settings.ALLOWED_HOSTS) CommaSeparatedStrings(['127.0.0.1', 'localhost']) >>> print(list(settings.ALLOWED_HOSTS)) ['127.0.0.1', 'localhost'] >>> print(len(settings.ALLOWED_HOSTS)) 2 >>> print(settings.ALLOWED_HOSTS[0]) 127.0.0.1 ``` ## Reading or modifying the environment In some cases you might want to read or modify the environment variables programmatically. This is particularly useful in testing, where you may want to override particular keys in the environment. Rather than reading from or writing to `os.environ`, you should use Starlette's `environ` instance. This instance is a mapping onto the standard `os.environ` that additionally protects you by raising an error if any environment variable is set *after* the point that it has already been read by the configuration. If you're using `pytest`, then you can set up any initial environment in `tests/conftest.py`. ```python title="tests/conftest.py" from starlette.config import environ environ['DEBUG'] = 'TRUE' ``` ## Reading prefixed environment variables You can namespace the environment variables by setting the `env_prefix` argument. ```python title="myproject/settings.py" import os from starlette.config import Config os.environ['APP_DEBUG'] = 'yes' os.environ['ENVIRONMENT'] = 'dev' config = Config(env_prefix='APP_') DEBUG = config('DEBUG') # looks up APP_DEBUG, returns "yes" ENVIRONMENT = config('ENVIRONMENT') # looks up APP_ENVIRONMENT, raises KeyError as variable is not defined ``` ## Custom encoding for environment files By default, Starlette reads environment files using UTF-8 encoding. You can specify a different encoding by setting the `encoding` argument. ```python title="myproject/settings.py" from starlette.config import Config # Using custom encoding for .env file config = Config(".env", encoding="latin-1") ``` ## A full example Structuring large applications can be complex.
You need proper separation of configuration and code, database isolation during tests, separate test and production databases, etc. Here we'll take a look at a complete example that demonstrates how we can start to structure an application. First, let's keep our settings, our database table definitions, and our application logic separated: ```python title="myproject/settings.py" from starlette.config import Config from starlette.datastructures import Secret config = Config(".env") DEBUG = config('DEBUG', cast=bool, default=False) SECRET_KEY = config('SECRET_KEY', cast=Secret) DATABASE_URL = config('DATABASE_URL') ``` ```python title="myproject/tables.py" import sqlalchemy # Database table definitions. metadata = sqlalchemy.MetaData() organisations = sqlalchemy.Table( ... ) ``` ```python title="myproject/app.py" from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.sessions import SessionMiddleware from starlette.routing import Route from myproject import settings async def homepage(request): ... routes = [ Route("/", endpoint=homepage) ] middleware = [ Middleware( SessionMiddleware, secret_key=settings.SECRET_KEY, ) ] app = Starlette(debug=settings.DEBUG, routes=routes, middleware=middleware) ``` Now let's deal with our test configuration. We'd like to create a new test database every time the test suite runs, and drop it once the tests complete. We'd also like to ensure that any test-specific environment variables are set before our application settings are imported. ```python title="tests/conftest.py" import pytest from starlette.config import environ from starlette.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy_utils import create_database, database_exists, drop_database # This line would raise an error if we use it after 'settings' has been imported.
environ['DEBUG'] = 'TRUE' from myproject import settings from myproject.app import app from myproject.tables import metadata @pytest.fixture(autouse=True, scope="session") def setup_test_database(): """ Create a clean test database every time the tests are run. """ url = settings.DATABASE_URL engine = create_engine(url) assert not database_exists(url), 'Test database already exists. Aborting tests.' create_database(url) # Create the test database. metadata.create_all(engine) # Create the tables. yield # Run the tests. drop_database(url) # Drop the test database. @pytest.fixture() def client(): """ Make a 'client' fixture available to test cases. """ # Our fixture is created within a context manager. This ensures that # application lifespan runs for every test case. with TestClient(app) as test_client: yield test_client ``` [twelve-factor]: https://12factor.net/config ================================================ FILE: docs/contributing.md ================================================ # Contributing Thank you for being interested in contributing to Starlette. There are many ways you can contribute to the project: - Try Starlette and [report bugs/issues you find](https://github.com/Kludex/starlette/issues/new) - [Implement new features](https://github.com/Kludex/starlette/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) - [Review Pull Requests of others](https://github.com/Kludex/starlette/pulls) - Write documentation - Participate in discussions ## Reporting Bugs or Other Issues Found something that Starlette should support? Stumbled upon some unexpected behaviour? Contributions should generally start out with [a discussion](https://github.com/Kludex/starlette/discussions). Possible bugs may be raised as a "Potential Issue" discussion, feature requests may be raised as an "Ideas" discussion. We can then determine if the discussion needs to be escalated into an "Issue" or not, or if we'd consider a pull request. 
Try to be as descriptive as you can and, in the case of a bug report, provide as much information as possible, such as: - OS platform - Python version - Installed dependencies and versions (`python -m pip freeze`) - Code snippet - Error traceback You should always try to reduce any examples to the *simplest possible case* that demonstrates the issue. ## Development To start developing Starlette, create a **fork** of the [Starlette repository](https://github.com/Kludex/starlette) on GitHub. Then clone your fork with the following command replacing `YOUR-USERNAME` with your GitHub username: ```shell $ git clone https://github.com/YOUR-USERNAME/starlette ``` You can now install the project and its dependencies using: ```shell $ cd starlette $ scripts/install ``` ## Testing and Linting We use custom shell scripts to automate the testing, linting, and documentation-building workflows. To run the tests, use: ```shell $ scripts/test ``` Any additional arguments will be passed to `pytest`. See the [pytest documentation](https://docs.pytest.org/en/latest/how-to/usage.html) for more information. For example, to run a single test script: ```shell $ scripts/test tests/test_applications.py ``` To run the code auto-formatting: ```shell $ scripts/lint ``` Lastly, to run code checks separately (they are also run as part of `scripts/test`), run: ```shell $ scripts/check ``` ## Documenting Documentation pages are located under the `docs/` folder. To run the documentation site locally (useful for previewing changes), use: ```shell $ scripts/docs ``` ## Resolving Build / CI Failures Once you've submitted your pull request, the test suite will automatically run, and the results will show up in GitHub. If the test suite fails, you'll want to click through to the "Details" link, and try to identify why the test suite failed.

Failing PR commit status

Here are some common ways the test suite can fail: ### Check Job Failed

Failing GitHub action lint job

This job failing means there is either a code formatting issue or a type-annotation issue. You can look at the job output to figure out why it failed, or run the following in a shell: ```shell $ scripts/check ``` It may be worth running `scripts/lint` to auto-format the code and, if that succeeds, committing the changes. ### Docs Job Failed This job failing means the documentation failed to build. This can happen for a variety of reasons, like invalid markdown or missing configuration within `mkdocs.yml`. ### Python 3.X Job Failed

Failing GitHub action test job

This job failing means the unit tests failed or not all code paths are covered by unit tests. If tests are failing, you will see this message under the coverage report: `=== 1 failed, 435 passed, 1 skipped, 1 xfailed in 11.09s ===` If tests succeed but coverage doesn't reach our current threshold, you will see this message under the coverage report: `FAIL Required test coverage of 100% not reached. Total coverage: 99.00%` ## Releasing *This section is targeted at Starlette maintainers.* Before releasing a new version, create a pull request that includes: - **An update to the changelog**: - We follow the format from [keepachangelog](https://keepachangelog.com/en/1.0.0/). - [Compare](https://github.com/Kludex/starlette/compare/) `main` with the tag of the latest release, and list all entries that are of interest to our users: - Things that **must** go in the changelog: added, changed, deprecated or removed features, and bug fixes. - Things that **should not** go in the changelog: changes to documentation, tests or tooling. - Try sorting entries in descending order of impact / importance. - Keep it concise and to-the-point. 🎯 - **A version bump**: see `__version__` in `starlette/__init__.py`. For an example, see [#1600](https://github.com/Kludex/starlette/pull/1600). Once the release PR is merged, create a [new release](https://github.com/Kludex/starlette/releases/new) including: - Tag version like `0.13.3`. - Release title `Version 0.13.3` - Description copied from the changelog. Once created, this release will be automatically uploaded to PyPI.
================================================ FILE: docs/css/custom.css ================================================ /* Lighter dark mode colors */ [data-md-color-scheme="slate"] { --md-default-bg-color: #263238; --md-default-fg-color: #e0e0e0; --md-code-bg-color: #2e3c43; } /* Announcement bar styling */ .announce-wrapper { display: flex; justify-content: center; align-items: center; height: 40px; min-height: 40px; background-color: var(--md-primary-fg-color); } .announce-wrapper #announce-msg { display: flex; align-items: center; justify-content: center; } .announce-wrapper #announce-msg div.item { display: none; } .announce-wrapper #announce-msg div.item:first-child { display: block; } a.announce-link:link, a.announce-link:visited { color: var(--md-primary-bg-color); text-decoration: none; font-weight: 500; } a.announce-link:hover { color: var(--md-accent-fg-color); text-decoration: underline; } ================================================ FILE: docs/database.md ================================================ Starlette is not strictly tied to any particular database implementation. You are free to use any async database library that you prefer. Some popular options include: - [SQLAlchemy](https://www.sqlalchemy.org/) - The Python SQL toolkit with native async support (2.0+). - [SQLModel](https://sqlmodel.tiangolo.com/) - SQL databases in Python, designed for simplicity, built on top of SQLAlchemy and Pydantic. - [Tortoise ORM](https://tortoise.github.io/) - An easy-to-use asyncio ORM inspired by Django. - [Piccolo](https://piccolo-orm.com/) - A fast, user-friendly ORM and query builder. Refer to your chosen database library's documentation for specific connection and query patterns. 
================================================ FILE: docs/endpoints.md ================================================ Starlette includes the classes `HTTPEndpoint` and `WebSocketEndpoint` that provide a class-based view pattern for handling HTTP method dispatching and WebSocket sessions. ### HTTPEndpoint The `HTTPEndpoint` class can be used as an ASGI application: ```python from starlette.responses import PlainTextResponse from starlette.endpoints import HTTPEndpoint class App(HTTPEndpoint): async def get(self, request): return PlainTextResponse("Hello, world!") ``` If you're using a Starlette application instance to handle routing, you can dispatch to an `HTTPEndpoint` class. Make sure to dispatch to the class itself, rather than to an instance of the class: ```python from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.endpoints import HTTPEndpoint from starlette.routing import Route class Homepage(HTTPEndpoint): async def get(self, request): return PlainTextResponse("Hello, world!") class User(HTTPEndpoint): async def get(self, request): username = request.path_params['username'] return PlainTextResponse(f"Hello, {username}") routes = [ Route("/", Homepage), Route("/{username}", User) ] app = Starlette(routes=routes) ``` HTTP endpoint classes will respond with "405 Method Not Allowed" responses for any request methods which do not map to a corresponding handler. ### WebSocketEndpoint The `WebSocketEndpoint` class is an ASGI application that presents a wrapper around the functionality of a `WebSocket` instance. The ASGI connection scope is accessible on the endpoint instance via `.scope`. The endpoint class also has an `encoding` attribute, which may optionally be set in order to validate the expected websocket data in the `on_receive` method.
The encoding types are: * `'json'` * `'bytes'` * `'text'` There are three overridable methods for handling specific ASGI websocket message types: * `async def on_connect(websocket, **kwargs)` * `async def on_receive(websocket, data)` * `async def on_disconnect(websocket, close_code)` ```python from starlette.endpoints import WebSocketEndpoint class App(WebSocketEndpoint): encoding = 'bytes' async def on_connect(self, websocket): await websocket.accept() async def on_receive(self, websocket, data): await websocket.send_bytes(b"Message: " + data) async def on_disconnect(self, websocket, close_code): pass ``` The `WebSocketEndpoint` can also be used with the `Starlette` application class: ```python import uvicorn from starlette.applications import Starlette from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint from starlette.responses import HTMLResponse from starlette.routing import Route, WebSocketRoute html = """ Chat

WebSocket Chat

""" class Homepage(HTTPEndpoint): async def get(self, request): return HTMLResponse(html) class Echo(WebSocketEndpoint): encoding = "text" async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") routes = [ Route("/", Homepage), WebSocketRoute("/ws", Echo) ] app = Starlette(routes=routes) ``` ================================================ FILE: docs/exceptions.md ================================================ Starlette allows you to install custom exception handlers to deal with how you return responses when errors or handled exceptions occur. ```python from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import HTMLResponse HTML_404_PAGE = ... HTML_500_PAGE = ... async def not_found(request: Request, exc: HTTPException): return HTMLResponse(content=HTML_404_PAGE, status_code=exc.status_code) async def server_error(request: Request, exc: Exception): return HTMLResponse(content=HTML_500_PAGE, status_code=500) exception_handlers = { 404: not_found, 500: server_error } app = Starlette(routes=routes, exception_handlers=exception_handlers) ``` If `debug` is enabled and an error occurs, then instead of using the installed 500 handler, Starlette will respond with a traceback response. ```python app = Starlette(debug=True, routes=routes, exception_handlers=exception_handlers) ``` As well as registering handlers for specific status codes, you can also register handlers for classes of exceptions. In particular you might want to override how the built-in `HTTPException` class is handled. For example, to use JSON-style responses: ```python async def http_exception(request: Request, exc: HTTPException): return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) exception_handlers = { HTTPException: http_exception } ``` The `HTTPException` also accepts a `headers` argument.
This allows the headers to be propagated to the response: ```python async def http_exception(request: Request, exc: HTTPException): return JSONResponse( {"detail": exc.detail}, status_code=exc.status_code, headers=exc.headers ) ``` You might also want to override how `WebSocketException` is handled: ```python async def websocket_exception(websocket: WebSocket, exc: WebSocketException): await websocket.close(code=1008) exception_handlers = { WebSocketException: websocket_exception } ``` ## Errors and handled exceptions It is important to differentiate between handled exceptions and errors. Handled exceptions do not represent error cases. They are coerced into appropriate HTTP responses, which are then sent through the standard middleware stack. By default the `HTTPException` class is used to manage any handled exceptions. Errors are any other exception that occurs within the application. These cases should bubble through the entire middleware stack as exceptions. Any error-logging middleware should ensure that it re-raises the exception all the way up to the server. In practical terms, the error handler used is `exception_handlers[500]` or `exception_handlers[Exception]`. Both keys `500` and `Exception` can be used. Note that the handler receives the original exception, which is not an `HTTPException`, so it should not rely on attributes such as `detail` or `status_code`. See below: ```python async def handle_error(request: Request, exc: Exception): # Perform some logic return JSONResponse({"detail": "Internal Server Error"}, status_code=500) exception_handlers = { Exception: handle_error # or "500: handle_error" } ``` Note that if a [`BackgroundTask`](background.md) raises an exception, it will be handled by the `handle_error` function, but at that point, the response has already been sent. In other words, the response created by `handle_error` will be discarded. If the error happens before the response is sent, then the response returned by the handler is used - in the above example, the returned `JSONResponse`.
In order to deal with this behaviour correctly, the middleware stack of a `Starlette` application is configured like this: * `ServerErrorMiddleware` - Returns 500 responses when server errors occur. * Installed middleware * `ExceptionMiddleware` - Deals with handled exceptions, and returns responses. * Router * Endpoints ## HTTPException The `HTTPException` class provides a base class that you can use for any handled exceptions. The `ExceptionMiddleware` implementation defaults to returning plain-text HTTP responses for any `HTTPException`. * `HTTPException(status_code, detail=None, headers=None)` You should only raise `HTTPException` inside routing or endpoints. Middleware classes should instead just return appropriate responses directly. You can use an `HTTPException` on a WebSocket endpoint. If it's raised before `websocket.accept()`, the connection is not upgraded to a WebSocket connection, and the proper HTTP response is returned. ```python from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket async def websocket_endpoint(websocket: WebSocket): raise HTTPException(status_code=400, detail="Bad request") app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)]) ``` ## WebSocketException You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints. * `WebSocketException(code=1008, reason=None)` You can use any valid close code as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). ================================================ FILE: docs/graphql.md ================================================ GraphQL support in Starlette was deprecated in version 0.15.0, and removed in version 0.17.0. Although GraphQL support is no longer built into Starlette, you can still use GraphQL with Starlette via third-party libraries.
These libraries all have Starlette-specific guides to help you do just that: - [Ariadne](https://ariadnegraphql.org/docs/starlette-integration.html) - [`starlette-graphene3`](https://github.com/ciscorn/starlette-graphene3#example) - [Strawberry](https://strawberry.rocks/docs/integrations/starlette) - [`tartiflette-asgi`](https://tartiflette.github.io/tartiflette-asgi/usage/#starlette) ================================================ FILE: docs/index.md ================================================

starlette starlette

✨ The little ASGI framework that shines. ✨

Build Status Package version Supported Python versions Discord

--- **Documentation**: https://starlette.dev **Source Code**: https://github.com/Kludex/starlette --- # Introduction Starlette is a lightweight [ASGI][asgi] framework/toolkit, which is ideal for building async web services in Python. It is production-ready, and gives you the following: * A lightweight, low-complexity HTTP web framework. * WebSocket support. * In-process background tasks. * Startup and shutdown events. * Test client built on `httpx`. * CORS, GZip, Static Files, Streaming responses. * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. * Few hard dependencies. * Compatible with `asyncio` and `trio` backends. * Great overall performance [against independent benchmarks][techempower]. ## Sponsorship Help us keep Starlette maintained and sustainable by [becoming a sponsor](https://github.com/sponsors/Kludex). **Current sponsors:**
FastAPI Hugging Face
## Installation ```shell pip install starlette ``` You'll also want to install an ASGI server, such as [uvicorn](https://www.uvicorn.org/), [daphne](https://github.com/django/daphne/), or [hypercorn](https://hypercorn.readthedocs.io/en/latest/). ```shell pip install uvicorn ``` ## Example ```python title="main.py" from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.routing import Route async def homepage(request): return JSONResponse({'hello': 'world'}) app = Starlette(debug=True, routes=[ Route('/', homepage), ]) ``` Then run the application... ```shell uvicorn main:app ``` ## Dependencies Starlette only requires `anyio`, and the following dependencies are optional: * [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. * [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support. You can install all of these with `pip install starlette[full]`. ## Framework or Toolkit Starlette is designed to be used either as a complete framework, or as an ASGI toolkit. You can use any of its components independently. ```python title="main.py" from starlette.responses import PlainTextResponse async def app(scope, receive, send): assert scope['type'] == 'http' response = PlainTextResponse('Hello, world!') await response(scope, receive, send) ``` Run the `app` application in `main.py`: ```shell $ uvicorn main:app INFO: Started server process [11509] INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) ``` Run uvicorn with `--reload` to enable auto-reloading on code changes. ## Modularity The modularity that Starlette is designed on promotes building re-usable components that can be shared between any ASGI framework. 
This should enable an ecosystem of shared middleware and mountable applications. The clean API separation also means it's easier to understand each component in isolation. ---

Starlette is BSD licensed code.
Designed & crafted with care.

— ⭐️ —

[asgi]: https://asgi.readthedocs.io/en/latest/ [httpx]: https://www.python-httpx.org/ [jinja2]: https://jinja.palletsprojects.com/ [python-multipart]: https://multipart.fastapiexpert.com/ [itsdangerous]: https://itsdangerous.palletsprojects.com/ [sqlalchemy]: https://www.sqlalchemy.org [pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation [techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf ================================================ FILE: docs/js/custom.js ================================================ function shuffle(array) { var currentIndex = array.length, temporaryValue, randomIndex; while (0 !== currentIndex) { randomIndex = Math.floor(Math.random() * currentIndex); currentIndex -= 1; temporaryValue = array[currentIndex]; array[currentIndex] = array[randomIndex]; array[randomIndex] = temporaryValue; } return array; } async function showRandomAnnouncement(groupId, timeInterval) { const announceGroup = document.getElementById(groupId); if (announceGroup) { let children = [].slice.call(announceGroup.children); children = shuffle(children) let index = 0 const announceRandom = () => { children.forEach((el, i) => { el.style.display = "none" }); children[index].style.display = "block" index = (index + 1) % children.length } announceRandom() setInterval(announceRandom, timeInterval) } } async function main() { showRandomAnnouncement('announce-msg', 5000) } document$.subscribe(() => { main() }) ================================================ FILE: docs/lifespan.md ================================================ Starlette applications can register a lifespan handler for dealing with code that needs to run before the application starts up, or when the application is shutting down. ```python import contextlib from starlette.applications import Starlette @contextlib.asynccontextmanager async def lifespan(app): async with some_async_resource(): print("Run at startup!") yield print("Run on shutdown!") routes = [ ... 
] app = Starlette(routes=routes, lifespan=lifespan) ``` Starlette will not start serving any incoming requests until the lifespan has been run. The lifespan teardown will run once all connections have been closed, and any in-process background tasks have completed. Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html) for managing asynchronous tasks. ## Lifespan State The lifespan has the concept of `state`, which is a dictionary that can be used to share objects between the lifespan and the requests. ```python import contextlib from typing import AsyncIterator, TypedDict import httpx from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route class State(TypedDict): http_client: httpx.AsyncClient @contextlib.asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[State]: async with httpx.AsyncClient() as client: yield {"http_client": client} async def homepage(request: Request) -> PlainTextResponse: client = request.state.http_client response = await client.get("https://www.example.com") return PlainTextResponse(response.text) app = Starlette( lifespan=lifespan, routes=[Route("/", homepage)] ) ``` The `state` received on the requests is a **shallow** copy of the state yielded by the lifespan handler. ## Accessing State The state can be accessed using either attribute-style or dictionary-style syntax. The dictionary-style syntax was introduced in Starlette 0.52.0 (January 2026) to improve type safety when using the lifespan state, given that `Request` is now generic over the state type.
```python from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import TypedDict import httpx from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route class State(TypedDict): http_client: httpx.AsyncClient @asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[State]: async with httpx.AsyncClient() as client: yield {"http_client": client} async def homepage(request: Request[State]) -> PlainTextResponse: client = request.state["http_client"] reveal_type(client) # Revealed type is 'httpx.AsyncClient' response = await client.get("https://www.example.com") return PlainTextResponse(response.text) app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)]) ``` This also works with WebSockets: ```python async def websocket_endpoint(websocket: WebSocket[State]) -> None: await websocket.accept() client = websocket.state["http_client"] response = await client.get("https://www.example.com") await websocket.send_text(response.text) await websocket.close() app = Starlette(lifespan=lifespan, routes=[WebSocketRoute("/ws", websocket_endpoint)]) ``` !!! note There were many attempts to make this work with attribute-style access instead of dictionary-style access, but none were satisfactory, given they would have been breaking changes, or there were typing limitations. For more details, see: - [@Kludex/starlette#issues/3005](https://github.com/Kludex/starlette/issues/3005) - [@python/typing#discussions/1457](https://github.com/python/typing/discussions/1457) - [@Kludex/starlette#pull/3036](https://github.com/Kludex/starlette/pull/3036) ## Running lifespan in tests You should use `TestClient` as a context manager, to ensure that the lifespan is called. 
```python from example import app from starlette.testclient import TestClient def test_homepage(): with TestClient(app) as client: # Application's lifespan is called on entering the block. response = client.get("/") assert response.status_code == 200 # And the lifespan's teardown is run when exiting the block. ``` ================================================ FILE: docs/middleware.md ================================================ Starlette includes several middleware classes for adding behavior that is applied across your entire application. These are all implemented as standard ASGI middleware classes, and can be applied either to Starlette or to any other ASGI application. ## Using middleware The Starlette application class allows you to include the ASGI middleware in a way that ensures that it remains wrapped by the exception handler. ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware routes = ... # Ensure that all requests include an 'example.com' or # '*.example.com' host header, and strictly enforce https-only access. middleware = [ Middleware( TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com'], ), Middleware(HTTPSRedirectMiddleware) ] app = Starlette(routes=routes, middleware=middleware) ``` Every Starlette application automatically includes two pieces of middleware by default: * `ServerErrorMiddleware` - Ensures that application exceptions may return a custom 500 page, or display an application traceback in DEBUG mode. This is *always* the outermost middleware layer. * `ExceptionMiddleware` - Adds exception handlers, so that particular types of expected exception cases can be associated with handler functions. For example raising `HTTPException(status_code=404)` within an endpoint will end up rendering a custom 404 page. 
Middleware is evaluated from top-to-bottom, so the flow of execution in our example application would look like this: * Middleware * `ServerErrorMiddleware` * `TrustedHostMiddleware` * `HTTPSRedirectMiddleware` * `ExceptionMiddleware` * Routing * Endpoint The following middleware implementations are available in the Starlette package: ## CORSMiddleware Adds appropriate [CORS headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) to outgoing responses in order to allow cross-origin requests from browsers. The default parameters used by the CORSMiddleware implementation are restrictive, so you'll need to explicitly enable particular origins, methods, or headers in order for browsers to be permitted to use them in a cross-origin context. ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware routes = ... middleware = [ Middleware(CORSMiddleware, allow_origins=['*']) ] app = Starlette(routes=routes, middleware=middleware) ``` The following arguments are supported: * `allow_origins` - A list of origins that should be permitted to make cross-origin requests. e.g. `['https://example.org', 'https://www.example.org']`. You can use `['*']` to allow any origin. * `allow_origin_regex` - A regex string to match against origins that should be permitted to make cross-origin requests. e.g. `'https://[a-zA-Z0-9-]+\.example\.org'`. Avoid using `.*` or `.+` as they match URL special characters (`/`, `@`, `#`, `?`) and may result in overly permissive origin matching. Use specific character classes like `[a-zA-Z0-9-]+` instead. * `allow_methods` - A list of HTTP methods that should be allowed for cross-origin requests. Defaults to `['GET']`. You can use `['*']` to allow all standard methods. * `allow_headers` - A list of HTTP request headers that should be supported for cross-origin requests. Defaults to `[]`. You can use `['*']` to allow all headers.
The `Accept`, `Accept-Language`, `Content-Language` and `Content-Type` headers are always allowed for CORS requests. * `allow_credentials` - Indicate that cookies should be supported for cross-origin requests. Defaults to `False`. Also, `allow_origins`, `allow_methods` and `allow_headers` cannot be set to `['*']` for credentials to be allowed; all of them must be explicitly specified. * `allow_private_network` - Indicates whether to accept cross-origin requests over a private network. Defaults to `False`. * `expose_headers` - Indicate any response headers that should be made accessible to the browser. Defaults to `[]`. * `max_age` - Sets a maximum time in seconds for browsers to cache CORS responses. Defaults to `600`. The middleware responds to two particular types of HTTP request... #### CORS preflight requests These are any `OPTIONS` requests with `Origin` and `Access-Control-Request-Method` headers. In this case the middleware will intercept the incoming request and respond with appropriate CORS headers, and either a 200 or 400 response for informational purposes. #### Simple requests Any request with an `Origin` header. In this case the middleware will pass the request through as normal, but will include appropriate CORS headers on the response. #### Private Network Access (PNA) Private Network Access is a browser security feature that restricts websites on public networks from accessing servers on private networks. When a website attempts to make such a cross-network request, the browser will send an `Access-Control-Request-Private-Network: true` header in the pre-flight request. If the `allow_private_network` flag is set to `True`, the middleware will include the `Access-Control-Allow-Private-Network: true` header in the response, allowing the request. If set to `False`, the middleware will return a 400 response, blocking the request.
### CORSMiddleware Global Enforcement When using CORSMiddleware with your Starlette application, it's important to ensure that CORS headers are applied even to error responses generated by unhandled exceptions. The recommended solution is to wrap the entire Starlette application with CORSMiddleware. This approach guarantees that even if an exception is caught by ServerErrorMiddleware (or other outer error-handling middleware), the response will still include the proper `Access-Control-Allow-Origin` header. For example, instead of adding CORSMiddleware as an inner `middleware` via the Starlette middleware parameter, you can wrap your application as follows, after configuring its routes and middleware: ```python from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware import uvicorn app = Starlette() # ... your routes and middleware configuration ... app = CORSMiddleware(app=app, allow_origins=["*"]) if __name__ == '__main__': uvicorn.run( app, host='0.0.0.0', port=8000 ) ``` ## SessionMiddleware Adds signed cookie-based HTTP sessions. Session information is readable by the client, but cannot be modified without invalidating the signature. The session cookie is always set with the `"HttpOnly"` flag, preventing client-side JavaScript from accessing it. Access or modify the session data using the `request.session` dictionary interface. The following arguments are supported: * `secret_key` - Should be a random string. * `session_cookie` - Defaults to "session". * `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session. * `same_site` - SameSite flag prevents the browser from sending the session cookie along with cross-site requests. Defaults to `'lax'`. * `path` - The path set for the session cookie. Defaults to `'/'`. * `https_only` - Indicate that the `"Secure"` flag should be set (can be used with HTTPS only). Defaults to `False`. Set this to `True` in production to ensure the session cookie is only sent over HTTPS.
* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains ([reference](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute)). ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.sessions import SessionMiddleware routes = ... middleware = [ Middleware(SessionMiddleware, secret_key=..., https_only=True) ] app = Starlette(routes=routes, middleware=middleware) ``` ## HTTPSRedirectMiddleware Enforces that all incoming requests must either be `https` or `wss`. Any incoming requests to `http` or `ws` will be redirected to the secure scheme instead. ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware routes = ... middleware = [ Middleware(HTTPSRedirectMiddleware) ] app = Starlette(routes=routes, middleware=middleware) ``` There are no configuration options for this middleware class. ## TrustedHostMiddleware Enforces that all incoming requests have a correctly set `Host` header, in order to guard against HTTP Host Header attacks. ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware routes = ... middleware = [ Middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com']) ] app = Starlette(routes=routes, middleware=middleware) ``` The following arguments are supported: * `allowed_hosts` - A list of domain names that should be allowed as hostnames. Wildcard domains such as `*.example.com` are supported for matching subdomains. To allow any hostname either use `allowed_hosts=["*"]` or omit the middleware. 
* `www_redirect` - If set to True, requests to non-www versions of the allowed hosts will be redirected to their www counterparts. Defaults to `True`. If an incoming request does not validate correctly then a 400 response will be sent. ## GZipMiddleware Handles GZip responses for any request that includes `"gzip"` in the `Accept-Encoding` header. The middleware will handle both standard and streaming responses. ??? info "Buffer on streaming responses" On streaming responses, the middleware will buffer the response before compressing it. The idea is that we don't want to compress every small chunk of data, as it would be inefficient. Instead, we buffer the response until it reaches a certain size, and then compress it. This may cause a delay in the response, as the middleware waits for the buffer to fill up before compressing it. ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware routes = ... middleware = [ Middleware(GZipMiddleware, minimum_size=1000, compresslevel=9) ] app = Starlette(routes=routes, middleware=middleware) ``` The following arguments are supported: * `minimum_size` - Do not GZip responses that are smaller than this minimum size in bytes. Defaults to `500`. * `compresslevel` - Used during GZip compression. It is an integer ranging from 1 to 9. Defaults to `9`. Lower value results in faster compression but larger file sizes, while higher value results in slower compression but smaller file sizes. The middleware won't GZip responses that already have either a `Content-Encoding` set, to prevent them from being encoded twice, or a `Content-Type` set to `text/event-stream`, to avoid compressing server-sent events. ## BaseHTTPMiddleware An abstract class that allows you to write ASGI middleware against a request/response interface. 
### Usage To implement a middleware class using `BaseHTTPMiddleware`, you must override the `async def dispatch(request, call_next)` method. ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware class CustomHeaderMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): response = await call_next(request) response.headers['Custom'] = 'Example' return response routes = ... middleware = [ Middleware(CustomHeaderMiddleware) ] app = Starlette(routes=routes, middleware=middleware) ``` If you want to provide configuration options to the middleware class you should override the `__init__` method, ensuring that the first argument is `app`, and any remaining arguments are optional keyword arguments. Make sure to set the `app` attribute on the instance if you do this. ```python class CustomHeaderMiddleware(BaseHTTPMiddleware): def __init__(self, app, header_value='Example'): super().__init__(app) self.header_value = header_value async def dispatch(self, request, call_next): response = await call_next(request) response.headers['Custom'] = self.header_value return response middleware = [ Middleware(CustomHeaderMiddleware, header_value='Customized') ] app = Starlette(routes=routes, middleware=middleware) ``` Middleware classes should not modify their state outside of the `__init__` method. Instead you should keep any state local to the `dispatch` method, or pass it around explicitly, rather than mutating the middleware instance. ### Limitations Currently, the `BaseHTTPMiddleware` has some known limitations: - Using `BaseHTTPMiddleware` will prevent changes to [`contextvars.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. 
That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/Kludex/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior). Importantly, this also means that if a `BaseHTTPMiddleware` is positioned earlier in the middleware stack, it will disrupt `contextvars` propagation for any subsequent Pure ASGI Middleware that relies on them. To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below. ## Pure ASGI Middleware The [ASGI spec](https://asgi.readthedocs.io/en/latest/) makes it possible to implement ASGI middleware using the ASGI interface directly, as a chain of ASGI applications that call into the next one. In fact, this is how middleware classes shipped with Starlette are implemented. This lower-level approach provides greater control over behavior and enhanced interoperability across frameworks and servers. It also overcomes the [limitations of `BaseHTTPMiddleware`](#limitations). ### Writing pure ASGI middleware The most common way to create an ASGI middleware is with a class. ```python class ASGIMiddleware: def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): await self.app(scope, receive, send) ``` The middleware above is the most basic ASGI middleware. It receives a parent ASGI application as an argument for its constructor, and implements an `async __call__` method which calls into that parent application. 
Some implementations such as [`asgi-cors`](https://github.com/simonw/asgi-cors/blob/10ef64bfcc6cd8d16f3014077f20a0fb8544ec39/asgi_cors.py) use an alternative style, using functions: ```python import functools def asgi_middleware(): def asgi_decorator(app): @functools.wraps(app) async def wrapped_app(scope, receive, send): await app(scope, receive, send) return wrapped_app return asgi_decorator ``` In any case, ASGI middleware must be callables that accept three arguments: `scope`, `receive`, and `send`. * `scope` is a dict holding information about the connection, where `scope["type"]` may be: * [`"http"`](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope): for HTTP requests. * [`"websocket"`](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope): for WebSocket connections. * [`"lifespan"`](https://asgi.readthedocs.io/en/latest/specs/lifespan.html#scope): for ASGI lifespan messages. * `receive` and `send` can be used to exchange ASGI event messages with the ASGI server — more on this below. The type and contents of these messages depend on the scope type. Learn more in the [ASGI specification](https://asgi.readthedocs.io/en/latest/specs/index.html). ### Using pure ASGI middleware Pure ASGI middleware can be used like any other middleware: ```python from starlette.applications import Starlette from starlette.middleware import Middleware from .middleware import ASGIMiddleware routes = ... middleware = [ Middleware(ASGIMiddleware), ] app = Starlette(..., middleware=middleware) ``` See also [Using middleware](#using-middleware). ### Type annotations There are two ways of annotating a middleware: using Starlette itself or [`asgiref`](https://github.com/django/asgiref). * Using Starlette: for most common use cases. 
```python from starlette.types import ASGIApp, Message, Scope, Receive, Send class ASGIMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": return await self.app(scope, receive, send) async def send_wrapper(message: Message) -> None: # ... Do something await send(message) await self.app(scope, receive, send_wrapper) ``` * Using [`asgiref`](https://github.com/django/asgiref): for more rigorous type hinting. ```python from asgiref.typing import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope from asgiref.typing import ASGIReceiveEvent, ASGISendEvent class ASGIMiddleware: def __init__(self, app: ASGI3Application) -> None: self.app = app async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: if scope["type"] != "http": await self.app(scope, receive, send) return async def send_wrapper(message: ASGISendEvent) -> None: # ... Do something await send(message) return await self.app(scope, receive, send_wrapper) ``` ### Common patterns #### Processing certain requests only ASGI middleware can apply specific behavior according to the contents of `scope`. For example, to only process HTTP requests, write this... ```python class ASGIMiddleware: def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): if scope["type"] != "http": await self.app(scope, receive, send) return ... # Do something here! await self.app(scope, receive, send) ``` Likewise, WebSocket-only middleware would guard on `scope["type"] != "websocket"`. The middleware may also act differently based on the request method, URL, headers, etc. #### Reusing Starlette components Starlette provides several data structures that accept the ASGI `scope`, `receive` and/or `send` arguments, allowing you to work at a higher level of abstraction. 
Such data structures include [`Request`](requests.md#request), [`Headers`](requests.md#headers), [`QueryParams`](requests.md#query-parameters), [`URL`](requests.md#url), etc. For example, you can instantiate a `Request` to more easily inspect an HTTP request: ```python from starlette.requests import Request class ASGIMiddleware: def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): if scope["type"] == "http": request = Request(scope) ... # Use `request.method`, `request.url`, `request.headers`, etc. await self.app(scope, receive, send) ``` You can also reuse [responses](responses.md), which are ASGI applications as well. #### Sending eager responses Inspecting the connection `scope` allows you to conditionally call into a different ASGI app. One use case might be sending a response without calling into the app. As an example, this middleware uses a dictionary to perform permanent redirects based on the requested path. This could be used to implement ongoing support of legacy URLs in case you need to refactor route URL patterns. ```python from starlette.datastructures import URL from starlette.responses import RedirectResponse class RedirectsMiddleware: def __init__(self, app, path_mapping: dict): self.app = app self.path_mapping = path_mapping async def __call__(self, scope, receive, send): if scope["type"] != "http": await self.app(scope, receive, send) return url = URL(scope=scope) if url.path in self.path_mapping: url = url.replace(path=self.path_mapping[url.path]) response = RedirectResponse(url, status_code=301) await response(scope, receive, send) return await self.app(scope, receive, send) ``` Example usage would look like this: ```python from starlette.applications import Starlette from starlette.middleware import Middleware routes = ... redirections = { "/v1/resource/": "/v2/resource/", # ... 
} middleware = [ Middleware(RedirectsMiddleware, path_mapping=redirections), ] app = Starlette(routes=routes, middleware=middleware) ``` #### Inspecting or modifying the request Request information can be accessed or changed by manipulating the `scope`. For a full example of this pattern, see Uvicorn's [`ProxyHeadersMiddleware`](https://github.com/encode/uvicorn/blob/fd4386fefb8fe8a4568831a7d8b2930d5fb61455/uvicorn/middleware/proxy_headers.py), which inspects and tweaks the `scope` when serving behind a frontend proxy. Additionally, wrapping the `receive` ASGI callable allows you to access or modify the HTTP request body by manipulating [`http.request`](https://asgi.readthedocs.io/en/latest/specs/www.html#request-receive-event) ASGI event messages. As an example, this middleware computes and logs the size of the incoming request body... ```python class LoggedRequestBodySizeMiddleware: def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): if scope["type"] != "http": await self.app(scope, receive, send) return body_size = 0 async def receive_logging_request_body_size(): nonlocal body_size message = await receive() assert message["type"] == "http.request" body_size += len(message.get("body", b"")) if not message.get("more_body", False): print(f"Size of request body was: {body_size} bytes") return message await self.app(scope, receive_logging_request_body_size, send) ``` Likewise, WebSocket middleware may manipulate [`websocket.receive`](https://asgi.readthedocs.io/en/latest/specs/www.html#receive-receive-event) ASGI event messages to inspect or alter incoming WebSocket data. For an example that changes the HTTP request body, see [`msgpack-asgi`](https://github.com/florimondmanca/msgpack-asgi). #### Inspecting or modifying the response Wrapping the `send` ASGI callable allows you to inspect or modify the HTTP response sent by the underlying application.
To do so, react to [`http.response.start`](https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event) or [`http.response.body`](https://asgi.readthedocs.io/en/latest/specs/www.html#response-body-send-event) ASGI event messages. As an example, this middleware adds some fixed extra response headers: ```python from starlette.datastructures import MutableHeaders class ExtraResponseHeadersMiddleware: def __init__(self, app, headers): self.app = app self.headers = headers async def __call__(self, scope, receive, send): if scope["type"] != "http": return await self.app(scope, receive, send) async def send_with_extra_headers(message): if message["type"] == "http.response.start": headers = MutableHeaders(scope=message) for key, value in self.headers: headers.append(key, value) await send(message) await self.app(scope, receive, send_with_extra_headers) ``` See also [`asgi-logger`](https://github.com/Kludex/asgi-logger/blob/main/asgi_logger/middleware.py) for an example that inspects the HTTP response and logs a configurable HTTP access log line. Likewise, WebSocket middleware may manipulate [`websocket.send`](https://asgi.readthedocs.io/en/latest/specs/www.html#send-send-event) ASGI event messages to inspect or alter outgoing WebSocket data. Note that if you change the response body, you will need to update the response `Content-Length` header to match the new response body length. See [`brotli-asgi`](https://github.com/fullonic/brotli-asgi) for a complete example. #### Passing information to endpoints If you need to share information with the underlying app or endpoints, you may store it into the `scope` dictionary. Note that this is a convention -- for example, Starlette uses this to share routing information with endpoints -- but it is not part of the ASGI specification. If you do so, be sure to avoid conflicts by using keys that have low chances of being used by other middleware or applications. 
For example, when including the middleware below, endpoints would be able to access `request.scope["asgi_transaction_id"]`. ```python import uuid class TransactionIDMiddleware: def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): scope["asgi_transaction_id"] = uuid.uuid4() await self.app(scope, receive, send) ``` #### Cleanup and error handling You can wrap the application in a `try/except/finally` block or a context manager to perform cleanup operations or do error handling. For example, the following middleware might collect metrics and process application exceptions... ```python import time class MonitoringMiddleware: def __init__(self, app): self.app = app async def __call__(self, scope, receive, send): start = time.time() try: await self.app(scope, receive, send) except Exception as exc: ... # Process the exception raise finally: end = time.time() elapsed = end - start ... # Submit `elapsed` as a metric to a monitoring backend ``` See also [`timing-asgi`](https://github.com/steinnes/timing-asgi) for a full example of this pattern. ### Gotchas #### ASGI middleware should be stateless Because ASGI is designed to handle concurrent requests, any connection-specific state should be scoped to the `__call__` implementation. Not doing so would typically lead to conflicting variable reads/writes across requests, and most likely bugs. As an example, this would conditionally replace the response body, if an `X-Mock` header is present in the response... === "✅ Do" ```python from starlette.datastructures import Headers class MockResponseBodyMiddleware: def __init__(self, app, content): self.app = app self.content = content async def __call__(self, scope, receive, send): if scope["type"] != "http": await self.app(scope, receive, send) return # A flag that we will turn `True` if the HTTP response # has the 'X-Mock' header. # ✅: Scoped to this function. 
should_mock = False async def maybe_send_with_mock_content(message): nonlocal should_mock if message["type"] == "http.response.start": headers = Headers(raw=message["headers"]) should_mock = headers.get("X-Mock") == "1" await send(message) elif message["type"] == "http.response.body": if should_mock: message = {"type": "http.response.body", "body": self.content} await send(message) await self.app(scope, receive, maybe_send_with_mock_content) ``` === "❌ Don't" ```python hl_lines="7-8" from starlette.datastructures import Headers class MockResponseBodyMiddleware: def __init__(self, app, content): self.app = app self.content = content # ❌: This variable would be read and written across requests! self.should_mock = False async def __call__(self, scope, receive, send): if scope["type"] != "http": await self.app(scope, receive, send) return async def maybe_send_with_mock_content(message): if message["type"] == "http.response.start": headers = Headers(raw=message["headers"]) self.should_mock = headers.get("X-Mock") == "1" await send(message) elif message["type"] == "http.response.body": if self.should_mock: message = {"type": "http.response.body", "body": self.content} await send(message) await self.app(scope, receive, maybe_send_with_mock_content) ``` See also [`GZipMiddleware`](https://github.com/Kludex/starlette/blob/9ef1b91c9c043197da6c3f38aa153fd874b95527/starlette/middleware/gzip.py) for a full example implementation that navigates this potential gotcha. ### Further reading This documentation should be enough to have a good basis on how to create an ASGI middleware. 
Nonetheless, there are great articles about the subject: - [Introduction to ASGI: Emergence of an Async Python Web Ecosystem](https://florimond.dev/en/posts/2019/08/introduction-to-asgi-async-python-web/) - [How to write ASGI middleware](https://pgjones.dev/blog/how-to-write-asgi-middleware-2021/) ## Using middleware in other frameworks To wrap ASGI middleware around other ASGI applications, you should use the more general pattern of wrapping the application instance: ```python app = TrustedHostMiddleware(app, allowed_hosts=['example.com']) ``` You can do this with a Starlette application instance too, but it is preferable to use the `middleware=` style, as it will: * Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`. * Preserve the top-level `app` instance. ## Applying middleware to groups of routes Middleware can also be added to `Mount` instances, which allows you to apply middleware to a group of routes or a sub-application: ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.routing import Mount, Route routes = [ Mount( "/", routes=[ Route( "/example", endpoint=..., ) ], middleware=[Middleware(GZipMiddleware)] ) ] app = Starlette(routes=routes) ``` Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is. This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses.
If you do want to apply the middleware logic to error responses only on some routes, you have a few options: * Add an `ExceptionMiddleware` onto the `Mount` * Add a `try/except` block to your middleware and return an error response from there * Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting. The `Route` and `WebSocketRoute` classes also accept a `middleware` argument, which allows you to apply middleware to a single route: ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.routing import Route routes = [ Route( "/example", endpoint=..., middleware=[Middleware(GZipMiddleware)] ) ] app = Starlette(routes=routes) ``` You can also apply middleware to the `Router` class, which allows you to apply middleware to a group of routes: ```python from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.routing import Route, Router routes = [ Route("/example", endpoint=...), Route("/another", endpoint=...), ] router = Router(routes=routes, middleware=[Middleware(GZipMiddleware)]) ``` ## Third party middleware #### [asgi-auth-github](https://github.com/simonw/asgi-auth-github) This middleware adds authentication to any ASGI application, requiring users to sign in using their GitHub account (via [OAuth](https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps/)). Access can be restricted to specific users or to members of specific GitHub organizations or teams. #### [asgi-csrf](https://github.com/simonw/asgi-csrf) Middleware for protecting against CSRF attacks.
This middleware implements the Double Submit Cookie pattern, where a cookie is set, then it is compared to a csrftoken hidden form field or an `x-csrftoken` HTTP header. #### [AuthlibMiddleware](https://github.com/aogier/starlette-authlib) A drop-in replacement for Starlette session middleware, using [authlib's jwt](https://docs.authlib.org/en/latest/jose/jwt.html) module. #### [BugsnagMiddleware](https://github.com/ashinabraham/starlette-bugsnag) A middleware class for logging exceptions to [Bugsnag](https://www.bugsnag.com/). #### [CSRFMiddleware](https://github.com/frankie567/starlette-csrf) Middleware for protecting against CSRF attacks. This middleware implements the Double Submit Cookie pattern, where a cookie is set, then it is compared to an `x-csrftoken` HTTP header. #### [EarlyDataMiddleware](https://github.com/HarrySky/starlette-early-data) Middleware and decorator for detecting and denying [TLSv1.3 early data](https://tools.ietf.org/html/rfc8470) requests. #### [PrometheusMiddleware](https://github.com/perdy/starlette-prometheus) A middleware class for capturing Prometheus metrics related to requests and responses, including in progress requests, timing... #### [ProxyHeadersMiddleware](https://github.com/encode/uvicorn/blob/main/uvicorn/middleware/proxy_headers.py) Uvicorn includes a middleware class for determining the client IP address, when proxy servers are being used, based on the `X-Forwarded-Proto` and `X-Forwarded-For` headers. For more complex proxy configurations, you might want to adapt this middleware. #### [RateLimitMiddleware](https://github.com/abersheeran/asgi-ratelimit) A rate limit middleware. Regular expression matches url; flexible rules; highly customizable. Very easy to use. #### [RequestIdMiddleware](https://github.com/snok/asgi-correlation-id) A middleware class for reading/generating request IDs and attaching them to application logs. 
#### [RollbarMiddleware](https://docs.rollbar.com/docs/starlette) A middleware class for logging exceptions, errors, and log messages to [Rollbar](https://www.rollbar.com). #### [StarletteOpentracing](https://github.com/acidjunk/starlette-opentracing) A middleware class that emits tracing info to [OpenTracing.io](https://opentracing.io/) compatible tracers and can be used to profile and monitor distributed applications. #### [SecureCookiesMiddleware](https://github.com/thearchitector/starlette-securecookies) Customizable middleware for adding automatic cookie encryption and decryption to Starlette applications, with extra support for existing cookie-based middleware. #### [TimingMiddleware](https://github.com/steinnes/timing-asgi) A middleware class to emit timing information (cpu and wall time) for each request which passes through it. Includes examples for how to emit these timings as statsd metrics. #### [WSGIMiddleware](https://github.com/abersheeran/a2wsgi) A middleware class in charge of converting a WSGI application into an ASGI one. ================================================ FILE: docs/overrides/main.html ================================================ {% extends "base.html" %} {% block extrahead %} {{ super() }} {% endblock %} {% block announce %}
{% endblock %} ================================================ FILE: docs/overrides/partials/toc-item.html ================================================
  • {{ toc_item.title }} {% if toc_item.children %} {% endif %} ================================================ FILE: docs/release-notes.md ================================================ --- toc_depth: 2 --- ## 1.0.0rc1 (February 23, 2026) We're ready! I'm thrilled to announce the first release candidate for Starlette 1.0. Starlette was created in June 2018 by Tom Christie, and has been on ZeroVer for years. Today, it's downloaded almost [10 million times a day](https://pypistats.org/packages/starlette), serves as the foundation for FastAPI, and has inspired many other frameworks. In the age of AI, Starlette continues to play an important role as a dependency of the Python MCP SDK. This release focuses on removing deprecated features that were marked for removal in 1.0.0, along with some last minute bug fixes. It's a release candidate, so we can gather feedback from the community before the final 1.0.0 release soon. A huge thank you to all the contributors who have helped make Starlette what it is today. In particular, I'd like to recognize: * [Kim Christie](https://github.com/lovelydinosaur) - The original creator of Starlette, Uvicorn, and MkDocs, and the current maintainer of HTTPX. Kim's work helped lay the foundation for the modern async Python ecosystem. * [Adrian Garcia Badaracco](https://github.com/adriangb) - One of the smartest people I know, whom I have the pleasure of working with at Pydantic. * [Thomas Grainger](https://github.com/graingert) - My async teacher, always ready to help with questions. * [Alex Grönholm](https://github.com/agronholm) - Another async mentor, always prompt to help with questions. * [Florimond Manca](https://github.com/florimondmanca) - Always present in the early days of both Starlette and Uvicorn, and helped a lot in the ecosystem. * [Amin Alaee](https://github.com/aminalaee) - Contributed a lot with file-related PRs. 
* [Sebastián Ramírez](https://github.com/tiangolo) - Maintains FastAPI upstream, and always in contact to help with upstream issues. * [Alex Oleshkevich](https://github.com/alex-oleshkevich) - Helped a lot on templates and many discussions. * [abersheeran](https://github.com/abersheeran) - My go-to person when I need help on many subjects. I'd also like to thank my sponsors for their support. A special thanks to [@tiangolo](https://github.com/tiangolo), [@huggingface](https://github.com/huggingface), and [@elevenlabs](https://github.com/elevenlabs) for their generous sponsorship, and to all my other sponsors: [@roboflow](https://github.com/roboflow), [@ogabrielluiz](https://github.com/ogabrielluiz), [@SaboniAmine](https://github.com/SaboniAmine), [@russbiggs](https://github.com/russbiggs), [@BryceBeagle](https://github.com/BryceBeagle), [@chdsbd](https://github.com/chdsbd), [@TheR1D](https://github.com/TheR1D), [@ddanier](https://github.com/ddanier), [@larsyngvelundin](https://github.com/larsyngvelundin), [@jpizquierdo](https://github.com/jpizquierdo), [@alixlahuec](https://github.com/alixlahuec), [@nathanchapman](https://github.com/nathanchapman), [@devid8642](https://github.com/devid8642), [@comet-ml](https://github.com/comet-ml), [@Evil0ctal](https://github.com/Evil0ctal), [@msehnout](https://github.com/msehnout), and [@codingjoe](https://github.com/codingjoe). #### Removed * Remove `on_startup` and `on_shutdown` parameters from `Starlette` and `Router`. Use the `lifespan` parameter instead [#3117](https://github.com/encode/starlette/pull/3117). * Remove `on_event()` decorator from `Starlette` and `Router`. Use the `lifespan` parameter instead [#3117](https://github.com/encode/starlette/pull/3117). * Remove `add_event_handler()` method from `Starlette` and `Router`. Use the `lifespan` parameter instead [#3117](https://github.com/encode/starlette/pull/3117). 
* Remove `startup()` and `shutdown()` methods from `Router` [#3117](https://github.com/encode/starlette/pull/3117). * Remove `@app.route()` decorator from `Starlette` and `Router`. Use `Route` in the `routes` parameter instead [#3117](https://github.com/encode/starlette/pull/3117). * Remove `@app.websocket_route()` decorator from `Starlette` and `Router`. Use `WebSocketRoute` in the `routes` parameter instead [#3117](https://github.com/encode/starlette/pull/3117). * Remove `@app.exception_handler()` decorator from `Starlette`. Use `exception_handlers` parameter instead [#3117](https://github.com/encode/starlette/pull/3117). * Remove `@app.middleware()` decorator from `Starlette`. Use `middleware` parameter instead [#3117](https://github.com/encode/starlette/pull/3117). * Remove `iscoroutinefunction_or_partial()` from `starlette.routing` [#3117](https://github.com/encode/starlette/pull/3117). * Remove `**env_options` parameter from `Jinja2Templates`. Use a preconfigured `jinja2.Environment` via the `env` parameter instead [#3118](https://github.com/encode/starlette/pull/3118). * Remove deprecated `TemplateResponse(name, context)` signature from `Jinja2Templates`. Use `TemplateResponse(request, name, ...)` instead [#3118](https://github.com/encode/starlette/pull/3118). * Remove deprecated `method` parameter from `FileResponse` [#3147](https://github.com/encode/starlette/pull/3147). #### Added * Add state generic to `WebSocket` [#3132](https://github.com/encode/starlette/pull/3132). #### Fixed * Include `bytes` unit in `Content-Range` header on 416 responses. * Handle null bytes in `StaticFiles` path [#3139](https://github.com/encode/starlette/pull/3139). * Use sort-based merge for `Range` header parsing [#3138](https://github.com/encode/starlette/pull/3138). * Set `Content-Type` instead of `Content-Range` on multi-range responses [#3142](https://github.com/encode/starlette/pull/3142). 
* Use CRLF line endings in multipart byterange boundaries [#3143](https://github.com/encode/starlette/pull/3143). * Avoid mutating `FileResponse` headers on range requests [#3144](https://github.com/encode/starlette/pull/3144). * Return explicit origin in CORS response when credentials are allowed [#3137](https://github.com/encode/starlette/pull/3137). * Enable `autoescape` by default in `Jinja2Templates` [#3148](https://github.com/encode/starlette/pull/3148). #### Changed * `jinja2` must now be installed to import `Jinja2Templates`. Previously it would only fail when instantiating the class [#3118](https://github.com/encode/starlette/pull/3118). ## 0.52.1 (January 18, 2026) #### Fixed * Only use `typing_extensions` in older Python versions [#3109](https://github.com/Kludex/starlette/pull/3109). ## 0.52.0 (January 18, 2026) In this release, `State` can be accessed using dictionary-style syntax for improved type safety ([#3036](https://github.com/Kludex/starlette/pull/3036)). ```python from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import TypedDict import httpx from starlette.applications import Starlette from starlette.requests import Request class State(TypedDict): http_client: httpx.AsyncClient @asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[State]: async with httpx.AsyncClient() as client: yield {"http_client": client} async def homepage(request: Request[State]): client = request.state["http_client"] # If you run the below line with mypy or pyright, it will reveal the correct type. reveal_type(client) # Revealed type is 'httpx.AsyncClient' ``` See [Accessing State](lifespan.md#accessing-state) for more details. ## 0.51.0 (January 10, 2026) #### Added * Add `allow_private_network` in `CORSMiddleware` [#3065](https://github.com/Kludex/starlette/pull/3065). 
#### Changed * Increase warning stacklevel on `DeprecationWarning` for wsgi module [#3082](https://github.com/Kludex/starlette/pull/3082). ## 0.50.0 (November 1, 2025) #### Removed * Drop Python 3.9 support [#3061](https://github.com/Kludex/starlette/pull/3061). ## 0.49.3 (November 1, 2025) This is the last release that supports Python 3.9, which will be dropped in the next minor release. #### Fixed * Relax strictness on `Middleware` type [#3059](https://github.com/Kludex/starlette/pull/3059). ## 0.49.2 (November 1, 2025) #### Fixed * Ignore `if-modified-since` header if `if-none-match` is present in `StaticFiles` [#3044](https://github.com/Kludex/starlette/pull/3044). ## 0.49.1 (October 28, 2025) This release fixes a security vulnerability in the parsing logic of the `Range` header in `FileResponse`. You can view the full security advisory: [GHSA-7f5h-v6xp-fcq8](https://github.com/Kludex/starlette/security/advisories/GHSA-7f5h-v6xp-fcq8) #### Fixed * Optimize the HTTP ranges parsing logic [4ea6e22b489ec388d6004cfbca52dd5b147127c5](https://github.com/Kludex/starlette/commit/4ea6e22b489ec388d6004cfbca52dd5b147127c5) ## 0.49.0 (October 28, 2025) #### Added * Add `encoding` parameter to `Config` class [#2996](https://github.com/Kludex/starlette/pull/2996). * Support multiple cookie headers in `Request.cookies` [#3029](https://github.com/Kludex/starlette/pull/3029). * Use `Literal` type for `WebSocketEndpoint` encoding values [#3027](https://github.com/Kludex/starlette/pull/3027). #### Changed * Do not pollute exception context in `Middleware` when using `BaseHTTPMiddleware` [#2976](https://github.com/Kludex/starlette/pull/2976). ## 0.48.0 (September 13, 2025) #### Added * Add official Python 3.14 support [#3013](https://github.com/Kludex/starlette/pull/3013). #### Changed * Implement [RFC9110](https://www.rfc-editor.org/rfc/rfc9110) http status names [#2939](https://github.com/Kludex/starlette/pull/2939). 
## 0.47.3 (August 24, 2025) #### Fixed * Use `asyncio.iscoroutinefunction` for Python 3.12 and older [#2984](https://github.com/Kludex/starlette/pull/2984). ## 0.47.2 (July 20, 2025) #### Fixed * Make `UploadFile` check for future rollover [#2962](https://github.com/Kludex/starlette/pull/2962). ## 0.47.1 (June 21, 2025) #### Fixed * Use `Self` in `TestClient.__enter__` [#2951](https://github.com/Kludex/starlette/pull/2951). * Allow async exception handlers to type-check [#2949](https://github.com/Kludex/starlette/pull/2949). ## 0.47.0 (May 29, 2025) #### Added * Add support for ASGI `pathsend` extension [#2671](https://github.com/Kludex/starlette/pull/2671). * Add `partitioned` attribute to `Response.set_cookie` [#2501](https://github.com/Kludex/starlette/pull/2501). #### Changed * Change `methods` parameter type from `list[str]` to `Collection[str]` [#2903](https://github.com/Kludex/starlette/pull/2903). * Replace `import typing` with `from typing import ...` in the whole codebase [#2867](https://github.com/Kludex/starlette/pull/2867). #### Fixed * Mark `ExceptionMiddleware.http_exception` as async to prevent thread creation [#2922](https://github.com/Kludex/starlette/pull/2922). ## 0.46.2 (April 13, 2025) #### Fixed * Prevent re-raising of exceptions from `BaseHTTPMiddleware` [#2911](https://github.com/Kludex/starlette/pull/2911). * Use correct index on backwards compatible logic in `TemplateResponse` [#2909](https://github.com/Kludex/starlette/pull/2909). ## 0.46.1 (March 8, 2025) #### Fixed * Allow relative directory path when `follow_symlinks=True` [#2896](https://github.com/Kludex/starlette/pull/2896). ## 0.46.0 (February 22, 2025) #### Added * `GZipMiddleware`: Make sure `Vary` header is always added if a response can be compressed [#2865](https://github.com/Kludex/starlette/pull/2865). #### Fixed * Raise exception from background task on `BaseHTTPMiddleware` [#2812](https://github.com/Kludex/starlette/pull/2812).
* `GZipMiddleware`: Don't compress on server-sent events [#2871](https://github.com/Kludex/starlette/pull/2871). #### Changed * `MultiPartParser`: Rename `max_file_size` to `spool_max_size` [#2780](https://github.com/Kludex/starlette/pull/2780). #### Deprecated * Add deprecation warning to `TestClient(timeout=...)` [#2840](https://github.com/Kludex/starlette/pull/2840). ## 0.45.3 (January 24, 2025) #### Fixed * Convert the directory to a string in `lookup_path` for the `commonpath` comparison [#2851](https://github.com/Kludex/starlette/pull/2851). ## 0.45.2 (January 4, 2025) #### Fixed * Make `create_memory_object_stream` compatible with old anyio versions once again, and bump anyio minimum version to 3.6.2 [#2833](https://github.com/Kludex/starlette/pull/2833). ## 0.45.1 (December 30, 2024) #### Fixed * Close `MemoryObjectReceiveStream` left unclosed upon exception in `BaseHTTPMiddleware` children [#2813](https://github.com/Kludex/starlette/pull/2813). * Collect errors more reliably from the WebSocket logic on the `TestClient` [#2814](https://github.com/Kludex/starlette/pull/2814). #### Refactor * Use a pair of memory object streams instead of two queues on the `TestClient` [#2829](https://github.com/Kludex/starlette/pull/2829). ## 0.45.0 (December 29, 2024) #### Removed * Drop Python 3.8 support [#2823](https://github.com/Kludex/starlette/pull/2823). * Remove `ExceptionMiddleware` import proxy from `starlette.exceptions` module [#2826](https://github.com/Kludex/starlette/pull/2826). * Remove deprecated `WS_1004_NO_STATUS_RCVD` and `WS_1005_ABNORMAL_CLOSURE` [#2827](https://github.com/Kludex/starlette/pull/2827). ## 0.44.0 (December 28, 2024) #### Added * Add `client` parameter to `TestClient` [#2810](https://github.com/Kludex/starlette/pull/2810). * Add `max_part_size` parameter to `Request.form()` [#2815](https://github.com/Kludex/starlette/pull/2815).
## 0.43.0 (December 25, 2024) #### Removed * Remove deprecated `allow_redirects` argument from `TestClient` [#2808](https://github.com/Kludex/starlette/pull/2808). #### Added * Make UUID path parameter conversion more flexible [#2806](https://github.com/Kludex/starlette/pull/2806). ## 0.42.0 (December 14, 2024) #### Added * Raise `ClientDisconnect` on `StreamingResponse` [#2732](https://github.com/Kludex/starlette/pull/2732). #### Fixed * Use the ETag from headers when parsing `If-Range` in `FileResponse` [#2761](https://github.com/Kludex/starlette/pull/2761). * Follow directory symlinks in `StaticFiles` when `follow_symlinks=True` [#2711](https://github.com/Kludex/starlette/pull/2711). * Bump minimum `python-multipart` version to `0.0.18` [0ba8395](https://github.com/Kludex/starlette/commit/0ba83959e609bbd460966f092287df1bbd564cc6). * Bump minimum `httpx` version to `0.27.0` [#2773](https://github.com/Kludex/starlette/pull/2773). ## 0.41.3 (November 18, 2024) #### Fixed * Exclude the query parameters from the `scope["raw_path"]` on the `TestClient` [#2716](https://github.com/Kludex/starlette/pull/2716). * Replace `dict` with `Mapping` on `HTTPException.headers` [#2749](https://github.com/Kludex/starlette/pull/2749). * Correct middleware argument passing and improve factory pattern [#2752](https://github.com/Kludex/starlette/pull/2752). ## 0.41.2 (October 27, 2024) #### Fixed * Revert bump on `python-multipart` on `starlette[full]` extras [#2737](https://github.com/Kludex/starlette/pull/2737). ## 0.41.1 (October 24, 2024) #### Fixed * Bump minimum `python-multipart` version to `0.0.13` [#2734](https://github.com/Kludex/starlette/pull/2734). * Change `python-multipart` import to `python_multipart` [#2733](https://github.com/Kludex/starlette/pull/2733). ## 0.41.0 (October 15, 2024) #### Added - Allow raising `HTTPException` before `websocket.accept()` [#2725](https://github.com/Kludex/starlette/pull/2725).
## 0.40.0 (October 15, 2024) This release fixes a Denial of service (DoS) via `multipart/form-data` requests. You can view the full security advisory: [GHSA-f96h-pmfr-66vw](https://github.com/Kludex/starlette/security/advisories/GHSA-f96h-pmfr-66vw) #### Fixed - Add `max_part_size` to `MultiPartParser` to limit the size of parts in `multipart/form-data` requests [fd038f3](https://github.com/Kludex/starlette/commit/fd038f3070c302bff17ef7d173dbb0b007617733). ## 0.39.2 (September 29, 2024) #### Fixed - Allow use of `request.url_for` when only "app" scope is available [#2672](https://github.com/Kludex/starlette/pull/2672). - Fix internal type hints to support `python-multipart==0.0.12` [#2708](https://github.com/Kludex/starlette/pull/2708). ## 0.39.1 (September 25, 2024) #### Fixed - Avoid regex re-compilation in `responses.py` and `schemas.py` [#2700](https://github.com/Kludex/starlette/pull/2700). - Improve performance of `get_route_path` by removing regular expression usage [#2701](https://github.com/Kludex/starlette/pull/2701). - Consider `FileResponse.chunk_size` when handling multiple ranges [#2703](https://github.com/Kludex/starlette/pull/2703). - Use `token_hex` for generating multipart boundary strings [#2702](https://github.com/Kludex/starlette/pull/2702). ## 0.39.0 (September 23, 2024) #### Added * Add support for [HTTP Range](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests) to `FileResponse` [#2697](https://github.com/Kludex/starlette/pull/2697). ## 0.38.6 (September 22, 2024) #### Fixed * Close unclosed `MemoryObjectReceiveStream` in `TestClient` [#2693](https://github.com/Kludex/starlette/pull/2693). ## 0.38.5 (September 7, 2024) #### Fixed * Schedule `BackgroundTasks` from within `BaseHTTPMiddleware` [#2688](https://github.com/Kludex/starlette/pull/2688). This behavior was removed in 0.38.3, and is now restored. 
## 0.38.4 (September 1, 2024) #### Fixed * Ensure accurate `root_path` removal in `get_route_path` function [#2600](https://github.com/Kludex/starlette/pull/2600). ## 0.38.3 (September 1, 2024) #### Added * Support for Python 3.13 [#2662](https://github.com/Kludex/starlette/pull/2662). #### Fixed * Don't poll for disconnects in `BaseHTTPMiddleware` via `StreamingResponse` [#2620](https://github.com/Kludex/starlette/pull/2620). ## 0.38.2 (July 27, 2024) #### Fixed * Do not assume all routines have `__name__` on `routing.get_name()` [#2648](https://github.com/Kludex/starlette/pull/2648). ## 0.38.1 (July 23, 2024) #### Removed * Revert "Add support for ASGI pathsend extension" [#2649](https://github.com/Kludex/starlette/pull/2649). ## 0.38.0 (July 20, 2024) #### Added * Allow use of `memoryview` in `StreamingResponse` and `Response` [#2576](https://github.com/Kludex/starlette/pull/2576) and [#2577](https://github.com/Kludex/starlette/pull/2577). * Send 404 instead of 500 when the requested filename is too long on `StaticFiles` [#2583](https://github.com/Kludex/starlette/pull/2583). #### Changed * Fail fast on invalid `Jinja2Templates` instantiation parameters [#2568](https://github.com/Kludex/starlette/pull/2568). * Check endpoint handler is async only once [#2536](https://github.com/Kludex/starlette/pull/2536). #### Fixed * Add proper synchronization to `WebSocketTestSession` [#2597](https://github.com/Kludex/starlette/pull/2597). ## 0.37.2 (March 5, 2024) #### Added * Add `bytes` to `_RequestData` type [#2510](https://github.com/Kludex/starlette/pull/2510). #### Fixed * Revert "Turn `scope["client"]` to `None` on `TestClient` (#2377)" [#2525](https://github.com/Kludex/starlette/pull/2525). * Remove deprecated `app` argument passed to `httpx.Client` on the `TestClient` [#2526](https://github.com/Kludex/starlette/pull/2526). ## 0.37.1 (February 9, 2024) #### Fixed * Warn instead of raising for missing env file on `Config` [#2485](https://github.com/Kludex/starlette/pull/2485).
## 0.37.0 (February 5, 2024) #### Added * Support the WebSocket Denial Response ASGI extension [#2041](https://github.com/Kludex/starlette/pull/2041). ## 0.36.3 (February 4, 2024) #### Fixed * Create `anyio.Event` in an async context [#2459](https://github.com/Kludex/starlette/pull/2459). ## 0.36.2 (February 3, 2024) #### Fixed * Upgrade `python-multipart` to `0.0.7` [13e5c26](https://github.com/Kludex/starlette/commit/13e5c26a27f4903924624736abd6131b2da80cc5). * Avoid duplicate charset on `Content-Type` [#2443](https://github.com/Kludex/starlette/pull/2443). ## 0.36.1 (January 23, 2024) #### Fixed * Check if "extensions" is in scope before checking the extension [#2438](https://github.com/Kludex/starlette/pull/2438). ## 0.36.0 (January 22, 2024) #### Added * Add support for ASGI `pathsend` extension [#2435](https://github.com/Kludex/starlette/pull/2435). * Cancel `WebSocketTestSession` on close [#2427](https://github.com/Kludex/starlette/pull/2427). * Raise `WebSocketDisconnect` when `WebSocket.send()` raises `IOError` [#2425](https://github.com/Kludex/starlette/pull/2425). * Raise `FileNotFoundError` when the `env_file` parameter on `Config` is not valid [#2422](https://github.com/Kludex/starlette/pull/2422). ## 0.35.1 (January 11, 2024) #### Fixed * Stop using the deprecated "method" parameter in `FileResponse` inside of `StaticFiles` [#2406](https://github.com/Kludex/starlette/pull/2406). * Make `typing-extensions` optional again [#2409](https://github.com/Kludex/starlette/pull/2409). ## 0.35.0 (January 11, 2024) #### Added * Add `*args` to `Middleware` and improve its type hints [#2381](https://github.com/Kludex/starlette/pull/2381). #### Fixed * Use `Iterable` instead of `Iterator` on `iterate_in_threadpool` [#2362](https://github.com/Kludex/starlette/pull/2362). #### Changed * Handle `root_path` to keep compatibility with mounted ASGI applications and WSGI [#2400](https://github.com/Kludex/starlette/pull/2400).
* Turn `scope["client"]` to `None` on `TestClient` [#2377](https://github.com/Kludex/starlette/pull/2377). ## 0.34.0 (December 16, 2023) ### Added * Use `ParamSpec` for `run_in_threadpool` [#2375](https://github.com/Kludex/starlette/pull/2375). * Add `UploadFile.__repr__` [#2360](https://github.com/Kludex/starlette/pull/2360). ### Fixed * Merge URLs properly on `TestClient` [#2376](https://github.com/Kludex/starlette/pull/2376). * Take weak ETags into consideration on `StaticFiles` [#2334](https://github.com/Kludex/starlette/pull/2334). ### Deprecated * Deprecate `FileResponse(method=...)` parameter [#2366](https://github.com/Kludex/starlette/pull/2366). ## 0.33.0 (December 1, 2023) ### Added * Add `middleware` per `Route`/`WebSocketRoute` [#2349](https://github.com/Kludex/starlette/pull/2349). * Add `middleware` per `Router` [#2351](https://github.com/Kludex/starlette/pull/2351). ### Fixed * Do not overwrite `"path"` and `"root_path"` scope keys [#2352](https://github.com/Kludex/starlette/pull/2352). * Set `ensure_ascii=False` on `json.dumps()` for `WebSocket.send_json()` [#2341](https://github.com/Kludex/starlette/pull/2341). ## 0.32.0.post1 (November 5, 2023) ### Fixed * Revert mkdocs-material from 9.1.17 to 9.4.7 [#2326](https://github.com/Kludex/starlette/pull/2326). ## 0.32.0 (November 4, 2023) ### Added * Send `reason` on `WebSocketDisconnect` [#2309](https://github.com/Kludex/starlette/pull/2309). * Add `domain` parameter to `SessionMiddleware` [#2280](https://github.com/Kludex/starlette/pull/2280). ### Changed * Inherit from `HTMLResponse` instead of `Response` on `_TemplateResponse` [#2274](https://github.com/Kludex/starlette/pull/2274). * Restore the `Response.render` type annotation to its pre-0.31.0 state [#2264](https://github.com/Kludex/starlette/pull/2264). ## 0.31.1 (August 26, 2023) ### Fixed * Fix import error when `exceptiongroup` isn't available [#2231](https://github.com/Kludex/starlette/pull/2231).
* Set `url_for` global for custom Jinja environments [#2230](https://github.com/Kludex/starlette/pull/2230). ## 0.31.0 (July 24, 2023) ### Added * Officially support Python 3.12 [#2214](https://github.com/Kludex/starlette/pull/2214). * Support AnyIO 4.0 [#2211](https://github.com/Kludex/starlette/pull/2211). * Strictly type annotate Starlette (strict mode on mypy) [#2180](https://github.com/Kludex/starlette/pull/2180). ### Fixed * Don't group duplicated headers on a single string when using the `TestClient` [#2219](https://github.com/Kludex/starlette/pull/2219). ## 0.30.0 (July 13, 2023) ### Removed * Drop Python 3.7 support [#2178](https://github.com/Kludex/starlette/pull/2178). ## 0.29.0 (July 13, 2023) ### Added * Add `follow_redirects` parameter to `TestClient` [#2207](https://github.com/Kludex/starlette/pull/2207). * Add `__str__` to `HTTPException` and `WebSocketException` [#2181](https://github.com/Kludex/starlette/pull/2181). * Warn users when using `lifespan` together with `on_startup`/`on_shutdown` [#2193](https://github.com/Kludex/starlette/pull/2193). * Collect routes from `Host` to generate the OpenAPI schema [#2183](https://github.com/Kludex/starlette/pull/2183). * Add `request` argument to `TemplateResponse` [#2191](https://github.com/Kludex/starlette/pull/2191). ### Fixed * Stop `body_stream` in case `more_body=False` on `BaseHTTPMiddleware` [#2194](https://github.com/Kludex/starlette/pull/2194). ## 0.28.0 (June 7, 2023) ### Changed * Reuse `Request`'s body buffer for call_next in `BaseHTTPMiddleware` [#1692](https://github.com/Kludex/starlette/pull/1692). * Move exception handling logic to `Route` [#2026](https://github.com/Kludex/starlette/pull/2026). ### Added * Add `env` parameter to `Jinja2Templates`, and deprecate `**env_options` [#2159](https://github.com/Kludex/starlette/pull/2159). * Add clear error message when `httpx` is not installed [#2177](https://github.com/Kludex/starlette/pull/2177). 
### Fixed * Allow "name" argument on `templates url_for()` [#2127](https://github.com/Kludex/starlette/pull/2127). ## 0.27.0 (May 16, 2023) This release fixes a path traversal vulnerability in `StaticFiles`. You can view the full security advisory: https://github.com/Kludex/starlette/security/advisories/GHSA-v5gw-mw7f-84px ### Added * Minify JSON websocket data via `send_json` https://github.com/Kludex/starlette/pull/2128 ### Fixed * Replace `commonprefix` by `commonpath` on `StaticFiles` [1797de4](https://github.com/Kludex/starlette/commit/1797de464124b090f10cf570441e8292936d63e3). * Convert ImportErrors into ModuleNotFoundError [#2135](https://github.com/Kludex/starlette/pull/2135). * Correct the RuntimeError message content in websockets [#2141](https://github.com/Kludex/starlette/pull/2141). ## 0.26.1 (March 13, 2023) ### Fixed * Fix typing of Lifespan to allow subclasses of Starlette [#2077](https://github.com/Kludex/starlette/pull/2077). ## 0.26.0.post1 (March 9, 2023) ### Fixed * Replace reference from Events to Lifespan on the mkdocs.yml [#2072](https://github.com/Kludex/starlette/pull/2072). ## 0.26.0 (March 9, 2023) ### Added * Support [lifespan state](lifespan.md) [#2060](https://github.com/Kludex/starlette/pull/2060), [#2065](https://github.com/Kludex/starlette/pull/2065) and [#2064](https://github.com/Kludex/starlette/pull/2064). ### Changed * Change `url_for` signature to return a `URL` instance [#1385](https://github.com/Kludex/starlette/pull/1385). ### Fixed * Allow "name" argument on `url_for()` and `url_path_for()` [#2050](https://github.com/Kludex/starlette/pull/2050). ### Deprecated * Deprecate `on_startup` and `on_shutdown` events [#2070](https://github.com/Kludex/starlette/pull/2070). 
## 0.25.0 (February 14, 2023) ### Fixed * Limit the number of fields and files when parsing `multipart/form-data` on the `MultiPartParser` [8c74c2c](https://github.com/Kludex/starlette/commit/8c74c2c8dba7030154f8af18e016136bea1938fa) and [#2036](https://github.com/Kludex/starlette/pull/2036). ## 0.24.0 (February 6, 2023) ### Added * Allow `StaticFiles` to follow symlinks [#1683](https://github.com/Kludex/starlette/pull/1683). * Allow using `Request.form()` as a context manager [#1903](https://github.com/Kludex/starlette/pull/1903). * Add `size` attribute to `UploadFile` [#1405](https://github.com/Kludex/starlette/pull/1405). * Add `env_prefix` argument to `Config` [#1990](https://github.com/Kludex/starlette/pull/1990). * Add template context processors [#1904](https://github.com/Kludex/starlette/pull/1904). * Support `str` and `datetime` on `expires` parameter on the `Response.set_cookie` method [#1908](https://github.com/Kludex/starlette/pull/1908). ### Changed * Lazily build the middleware stack [#2017](https://github.com/Kludex/starlette/pull/2017). * Make the `file` argument required on `UploadFile` [#1413](https://github.com/Kludex/starlette/pull/1413). * Use debug extension instead of custom response template extension [#1991](https://github.com/Kludex/starlette/pull/1991). ### Fixed * Fix URL parsing of IPv6 URLs on `URL.replace` [#1965](https://github.com/Kludex/starlette/pull/1965). ## 0.23.1 (December 9, 2022) ### Fixed * Only stop receiving stream on `body_stream` if body is empty on the `BaseHTTPMiddleware` [#1940](https://github.com/Kludex/starlette/pull/1940). ## 0.23.0 (December 5, 2022) ### Added * Add `headers` parameter to the `TestClient` [#1966](https://github.com/Kludex/starlette/pull/1966). ### Deprecated * Deprecate `Starlette` and `Router` decorators [#1897](https://github.com/Kludex/starlette/pull/1897). ### Fixed * Fix bug on `FloatConvertor` regex [#1973](https://github.com/Kludex/starlette/pull/1973).
## 0.22.0 (November 17, 2022) ### Changed * Bypass `GZipMiddleware` when response includes `Content-Encoding` [#1901](https://github.com/Kludex/starlette/pull/1901). ### Fixed * Remove unneeded `unquote()` from query parameters on the `TestClient` [#1953](https://github.com/Kludex/starlette/pull/1953). * Make sure `MutableHeaders._list` is actually a `list` [#1917](https://github.com/Kludex/starlette/pull/1917). * Import compatibility with the next version of `AnyIO` [#1936](https://github.com/Kludex/starlette/pull/1936). ## 0.21.0 (September 26, 2022) This release replaces the underlying HTTP client used on the `TestClient` (`requests` :arrow_right: `httpx`), and as those clients [differ _a bit_ on their API](https://www.python-httpx.org/compatibility/), your test suite will likely break. To make the migration smoother, you can use the [`bump-testclient`](https://github.com/Kludex/bump-testclient) tool. ### Changed * Replace `requests` with `httpx` in `TestClient` [#1376](https://github.com/Kludex/starlette/pull/1376). ### Added * Add `WebSocketException` and support for WebSocket exception handlers [#1263](https://github.com/Kludex/starlette/pull/1263). * Add `middleware` parameter to `Mount` class [#1649](https://github.com/Kludex/starlette/pull/1649). * Officially support Python 3.11 [#1863](https://github.com/Kludex/starlette/pull/1863). * Implement `__repr__` for route classes [#1864](https://github.com/Kludex/starlette/pull/1864). ### Fixed * Fix bug on which `BackgroundTasks` were cancelled when using `BaseHTTPMiddleware` and client disconnected [#1715](https://github.com/Kludex/starlette/pull/1715). ## 0.20.4 (June 28, 2022) ### Fixed * Remove converter from path when generating OpenAPI schema [#1648](https://github.com/Kludex/starlette/pull/1648). ## 0.20.3 (June 10, 2022) ### Fixed * Revert "Allow `StaticFiles` to follow symlinks" [#1681](https://github.com/Kludex/starlette/pull/1681). 
## 0.20.2 (June 7, 2022) ### Fixed * Fix regression on route paths with colons [#1675](https://github.com/Kludex/starlette/pull/1675). * Allow `StaticFiles` to follow symlinks [#1337](https://github.com/Kludex/starlette/pull/1377). ## 0.20.1 (May 28, 2022) ### Fixed * Improve detection of async callables [#1444](https://github.com/Kludex/starlette/pull/1444). * Send 400 (Bad Request) when `boundary` is missing [#1617](https://github.com/Kludex/starlette/pull/1617). * Send 400 (Bad Request) when missing "name" field on `Content-Disposition` header [#1643](https://github.com/Kludex/starlette/pull/1643). * Do not send empty data to `StreamingResponse` on `BaseHTTPMiddleware` [#1609](https://github.com/Kludex/starlette/pull/1609). * Add `__bool__` dunder for `Secret` [#1625](https://github.com/Kludex/starlette/pull/1625). ## 0.20.0 (May 3, 2022) ### Removed * Drop Python 3.6 support [#1357](https://github.com/Kludex/starlette/pull/1357) and [#1616](https://github.com/Kludex/starlette/pull/1616). ## 0.19.1 (April 22, 2022) ### Fixed * Fix inference of `Route.name` when created from methods [#1553](https://github.com/Kludex/starlette/pull/1553). * Avoid `TypeError` on `websocket.disconnect` when code is `None` [#1574](https://github.com/Kludex/starlette/pull/1574). ### Deprecated * Deprecate `WS_1004_NO_STATUS_RCVD` and `WS_1005_ABNORMAL_CLOSURE` in favor of `WS_1005_NO_STATUS_RCVD` and `WS_1006_ABNORMAL_CLOSURE`, as the previous constants didn't match the [WebSockets specs](https://www.iana.org/assignments/websocket/websocket.xhtml) [#1580](https://github.com/Kludex/starlette/pull/1580). ## 0.19.0 (March 9, 2022) ### Added * Error handler will always run, even if the error happens on a background task [#761](https://github.com/Kludex/starlette/pull/761). * Add `headers` parameter to `HTTPException` [#1435](https://github.com/Kludex/starlette/pull/1435). 
* Internal responses with `405` status code insert an `Allow` header, as described by [RFC 7231](https://datatracker.ietf.org/doc/html/rfc7231#section-6.5.5) [#1436](https://github.com/Kludex/starlette/pull/1436). * The `content` argument in `JSONResponse` is now required [#1431](https://github.com/Kludex/starlette/pull/1431). * Add custom URL convertor register [#1437](https://github.com/Kludex/starlette/pull/1437). * Add content disposition type parameter to `FileResponse` [#1266](https://github.com/Kludex/starlette/pull/1266). * Add next query param with original request URL in requires decorator [#920](https://github.com/Kludex/starlette/pull/920). * Add `raw_path` to `TestClient` scope [#1445](https://github.com/Kludex/starlette/pull/1445). * Add union operators to `MutableHeaders` [#1240](https://github.com/Kludex/starlette/pull/1240). * Display missing route details on debug page [#1363](https://github.com/Kludex/starlette/pull/1363). * Change `anyio` required version range to `>=3.4.0,<5.0` [#1421](https://github.com/Kludex/starlette/pull/1421) and [#1460](https://github.com/Kludex/starlette/pull/1460). * Add `typing-extensions>=3.10` requirement - used only on lower versions than Python 3.10 [#1475](https://github.com/Kludex/starlette/pull/1475). ### Fixed * Prevent `BaseHTTPMiddleware` from hiding errors of `StreamingResponse` and mounted applications [#1459](https://github.com/Kludex/starlette/pull/1459). * `SessionMiddleware` uses an explicit `path=...`, instead of defaulting to the ASGI 'root_path' [#1512](https://github.com/Kludex/starlette/pull/1512). * `Request.client` is now compliant with the ASGI specifications [#1462](https://github.com/Kludex/starlette/pull/1462). * Raise `KeyError` at early stage for missing boundary [#1349](https://github.com/Kludex/starlette/pull/1349). ### Deprecated * Deprecate WSGIMiddleware in favor of a2wsgi [#1504](https://github.com/Kludex/starlette/pull/1504). 
* Deprecate `run_until_first_complete` [#1443](https://github.com/Kludex/starlette/pull/1443). ## 0.18.0 (January 23, 2022) ### Added * Change default chunk size from 4 KB to 64 KB on `FileResponse` [#1345](https://github.com/Kludex/starlette/pull/1345). * Add support for `functools.partial` in `WebSocketRoute` [#1356](https://github.com/Kludex/starlette/pull/1356). * Add `StaticFiles` packages with directory [#1350](https://github.com/Kludex/starlette/pull/1350). * Allow environment options in `Jinja2Templates` [#1401](https://github.com/Kludex/starlette/pull/1401). * Allow HEAD method on `HTTPEndpoint` [#1346](https://github.com/Kludex/starlette/pull/1346). * Accept additional headers on `websocket.accept` message [#1361](https://github.com/Kludex/starlette/pull/1361) and [#1422](https://github.com/Kludex/starlette/pull/1422). * Add `reason` to `WebSocket` close ASGI event [#1417](https://github.com/Kludex/starlette/pull/1417). * Add headers attribute to `UploadFile` [#1382](https://github.com/Kludex/starlette/pull/1382). * Don't omit `Content-Length` header for `Content-Length: 0` cases [#1395](https://github.com/Kludex/starlette/pull/1395). * Don't set headers for responses with 1xx, 204 and 304 status codes [#1397](https://github.com/Kludex/starlette/pull/1397). * `SessionMiddleware.max_age` now accepts `None`, so the cookie can last as long as the browser session [#1387](https://github.com/Kludex/starlette/pull/1387). ### Fixed * Tweak the `hashlib.md5()` call used for `FileResponse` ETag generation. The [`usedforsecurity`](https://bugs.python.org/issue9216) flag is set to `False` if it is available on the system. This fixes an error raised on systems with [FIPS](https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/FIPS_Mode_-_an_explanation) enabled [#1366](https://github.com/Kludex/starlette/pull/1366) and [#1410](https://github.com/Kludex/starlette/pull/1410). * Fix `path_params` type on `url_path_for()` method i.e.
turn `str` into `Any` [#1341](https://github.com/Kludex/starlette/pull/1341). * `Host` now ignores `port` on routing [#1322](https://github.com/Kludex/starlette/pull/1322). ## 0.17.1 (November 17, 2021) ### Fixed * Fix `IndexError` in authentication `requires` when wrapped function arguments are distributed between `*args` and `**kwargs` [#1335](https://github.com/Kludex/starlette/pull/1335). ## 0.17.0 (November 4, 2021) ### Added * `Response.delete_cookie` now accepts the same parameters as `Response.set_cookie` [#1228](https://github.com/Kludex/starlette/pull/1228). * Update the `Jinja2Templates` constructor to allow `PathLike` [#1292](https://github.com/Kludex/starlette/pull/1292). ### Fixed * Fix BadSignature exception handling in SessionMiddleware [#1264](https://github.com/Kludex/starlette/pull/1264). * Change `HTTPConnection.__getitem__` return type from `str` to `typing.Any` [#1118](https://github.com/Kludex/starlette/pull/1118). * Change `ImmutableMultiDict.getlist` return type from `typing.List[str]` to `typing.List[typing.Any]` [#1235](https://github.com/Kludex/starlette/pull/1235). * Handle `OSError` exceptions on `StaticFiles` [#1220](https://github.com/Kludex/starlette/pull/1220). * Fix `StaticFiles` 404.html in HTML mode [#1314](https://github.com/Kludex/starlette/pull/1314). * Prevent anyio.ExceptionGroup in error views under a BaseHTTPMiddleware [#1262](https://github.com/Kludex/starlette/pull/1262). ### Removed * Remove GraphQL support [#1198](https://github.com/Kludex/starlette/pull/1198). 
## 0.16.0 (July 19, 2021) ### Added * Added [Encode](https://github.com/sponsors/encode) funding option [#1219](https://github.com/Kludex/starlette/pull/1219) ### Fixed * `starlette.websockets.WebSocket` instances are now hashable and compare by identity [#1039](https://github.com/Kludex/starlette/pull/1039) * A number of fixes related to running task groups in lifespan [#1213](https://github.com/Kludex/starlette/pull/1213), [#1227](https://github.com/Kludex/starlette/pull/1227) ### Deprecated/removed * The method `starlette.templates.Jinja2Templates.get_env` was removed [#1218](https://github.com/Kludex/starlette/pull/1218) * The ClassVar `starlette.testclient.TestClient.async_backend` was removed, the backend is now configured using constructor kwargs [#1211](https://github.com/Kludex/starlette/pull/1211) * Passing an Async Generator Function or a Generator Function to `starlette.routing.Router(lifespan=)` is deprecated. You should wrap your lifespan in `@contextlib.asynccontextmanager`. [#1227](https://github.com/Kludex/starlette/pull/1227) [#1110](https://github.com/Kludex/starlette/pull/1110) ## 0.15.0 (June 23, 2021) This release includes major changes to the low-level asynchronous parts of Starlette. As a result, **Starlette now depends on [AnyIO](https://anyio.readthedocs.io/en/stable/)** and some minor API changes have occurred. Another significant change with this release is the **deprecation of built-in GraphQL support**. ### Added * Starlette now supports [Trio](https://trio.readthedocs.io/en/stable/) as an async runtime via AnyIO - [#1157](https://github.com/Kludex/starlette/pull/1157). * `TestClient.websocket_connect()` now must be used as a context manager. * Initial support for Python 3.10 - [#1201](https://github.com/Kludex/starlette/pull/1201). * The compression level used in `GZipMiddleware` is now adjustable - [#1128](https://github.com/Kludex/starlette/pull/1128). ### Fixed * Several fixes to `CORSMiddleware`. 
See [#1111](https://github.com/Kludex/starlette/pull/1111), [#1112](https://github.com/Kludex/starlette/pull/1112), [#1113](https://github.com/Kludex/starlette/pull/1113), [#1199](https://github.com/Kludex/starlette/pull/1199). * Improved exception messages in the case of duplicated path parameter names - [#1177](https://github.com/Kludex/starlette/pull/1177). * `RedirectResponse` now uses `quote` instead of `quote_plus` encoding for the `Location` header to better match the behaviour in other frameworks such as Django - [#1164](https://github.com/Kludex/starlette/pull/1164). * Exception causes are now preserved in more cases - [#1158](https://github.com/Kludex/starlette/pull/1158). * Session cookies now use the ASGI root path in the case of mounted applications - [#1147](https://github.com/Kludex/starlette/pull/1147). * Fixed a cache invalidation bug when static files were deleted in certain circumstances - [#1023](https://github.com/Kludex/starlette/pull/1023). * Improved memory usage of `BaseHTTPMiddleware` when handling large responses - [#1012](https://github.com/Kludex/starlette/issues/1012) fixed via #1157 ### Deprecated/removed * Built-in GraphQL support via the `GraphQLApp` class has been deprecated and will be removed in a future release. Please see [#619](https://github.com/Kludex/starlette/issues/619). GraphQL is not supported on Python 3.10. * The `executor` parameter to `GraphQLApp` was removed. Use `executor_class` instead. * The `workers` parameter to `WSGIMiddleware` was removed. This hasn't had any effect since Starlette v0.6.3. ## 0.14.2 (February 2, 2021) ### Fixed * Fixed `ServerErrorMiddleware` compatibility with Python 3.9.1/3.8.7 when debug mode is enabled - [#1132](https://github.com/Kludex/starlette/pull/1132). * Fixed unclosed socket `ResourceWarning`s when using the `TestClient` with WebSocket endpoints - #1132. 
* Improved detection of `async` endpoints wrapped in `functools.partial` on Python 3.8+ - [#1106](https://github.com/Kludex/starlette/pull/1106). ## 0.14.1 (November 9th, 2020) ### Removed * `UJSONResponse` was removed (this change was intended to be included in 0.14.0). Please see the [documentation](https://starlette.dev/responses/#custom-json-serialization) for how to implement responses using custom JSON serialization - [#1074](https://github.com/Kludex/starlette/pull/1047). ## 0.14.0 (November 8th, 2020) ### Added * Starlette now officially supports Python3.9. * In `StreamingResponse`, allow custom async iterator such as objects from classes implementing `__aiter__`. * Allow usage of `functools.partial` async handlers in Python versions 3.6 and 3.7. * Add 418 I'm A Teapot status code. ### Changed * Create tasks from handler coroutines before sending them to `asyncio.wait`. * Use `format_exception` instead of `format_tb` in `ServerErrorMiddleware`'s `debug` responses. * Be more lenient with handler arguments when using the `requires` decorator. ## 0.13.8 * Revert `Queue(maxsize=1)` fix for `BaseHTTPMiddleware` middleware classes and streaming responses. * The `StaticFiles` constructor now allows `pathlib.Path` in addition to strings for its `directory` argument. ## 0.13.7 * Fix high memory usage when using `BaseHTTPMiddleware` middleware classes and streaming responses. ## 0.13.6 * Fix 404 errors with `StaticFiles`. ## 0.13.5 * Add support for `Starlette(lifespan=...)` functions. * More robust path-traversal check in StaticFiles app. * Fix WSGI PATH_INFO encoding. * RedirectResponse now accepts optional background parameter * Allow path routes to contain regex meta characters * Treat ASGI HTTP 'body' as an optional key. * Don't use thread pooling for writing to in-memory upload files. ## 0.13.0 * Switch to promoting application configuration on init style everywhere. 
This means dropping the decorator style in favour of declarative routing tables and middleware definitions. ## 0.12.12 * Fix `request.url_for()` for the Mount-within-a-Mount case. ## 0.12.11 * Fix `request.url_for()` when an ASGI `root_path` is being used. ## 0.12.1 * Add `URL.include_query_params(**kwargs)` * Add `URL.replace_query_params(**kwargs)` * Add `URL.remove_query_params(param_names)` * `request.state` properly persisting across middleware. * Added `request.scope` interface. ## 0.12.0 * Switch to ASGI 3.0. * Fixes to CORS middleware. * Add `StaticFiles(html=True)` support. * Fix path quoting in redirect responses. ## 0.11.1 * Add `request.state` interface, for storing arbitrary additional information. * Support disabling GraphiQL with `GraphQLApp(..., graphiql=False)`. ## 0.11.0 * `DatabaseMiddleware` is now dropped in favour of `databases` * Templates are no longer configured on the application instance. Use `templates = Jinja2Templates(directory=...)` and `return templates.TemplateResponse('index.html', {"request": request})` * Schema generation is no longer attached to the application instance. Use `schemas = SchemaGenerator(...)` and `return schemas.OpenAPIResponse(request=request)` * `LifespanMiddleware` is dropped in favor of router-based lifespan handling. * Application instances now accept a `routes` argument, `Starlette(routes=[...])` * Schema generation now includes mounted routes. ## 0.10.6 * Add `Lifespan` routing component. ## 0.10.5 * Ensure `templating` does not strictly require `jinja2` to be installed. ## 0.10.4 * Templates are now configured independently from the application instance. `templates = Jinja2Templates(directory=...)`. Existing API remains in place, but is no longer documented, and will be deprecated in due course. See the template documentation for more details. ## 0.10.3 * Move to independent `databases` package instead of `DatabaseMiddleware`. 
Existing API remains in place, but is no longer documented, and will be deprecated in due course. ## 0.10.2 * Don't drop explicit port numbers on redirects from `HTTPSRedirectMiddleware`. ## 0.10.1 * Add MySQL database support. * Add host-based routing. ## 0.10.0 * WebSockets now default to sending/receiving JSON over text data frames. Use `.send_json(data, mode="binary")` and `.receive_json(mode="binary")` for binary framing. * `GraphQLApp` now takes an `executor_class` argument, which should be used in preference to the existing `executor` argument. Resolves an issue with async executors being instantiated before the event loop was set up. The `executor` argument is expected to be deprecated in the next minor or major release. * Authentication and the `@requires` decorator now support WebSocket endpoints. * `MultiDict` and `ImmutableMultiDict` classes are available in `starlette.datastructures`. * `QueryParams` is now instantiated with standard dict-style `*args, **kwargs` arguments. ## 0.9.11 * Session cookies now include browser 'expires', in addition to the existing signed expiry. * `request.form()` now returns a multi-dict interface. * The query parameter multi-dict implementation now mirrors `dict` more correctly for the behavior of `.keys()`, `.values()`, and `.items()` when multiple same-key items occur. * Use `urlsplit` throughout in favor of `urlparse`. ## 0.9.10 * Support `@requires(...)` on class methods. * Apply URL escaping to form data. * Support `HEAD` requests automatically. * Add `await request.is_disconnected()`. * Pass operationName to GraphQL executor. ## 0.9.9 * Add `TemplateResponse`. * Add `CommaSeparatedStrings` datatype. * Add `BackgroundTasks` for multiple tasks. * Common subclass for `Request` and `WebSocket`, to eg. share `session` functionality. * Expose remote address with `request.client`. ## 0.9.8 * Add `request.database.executemany`. ## 0.9.7 * Ensure that `AuthenticationMiddleware` handles lifespan messages correctly.
## 0.9.6 * Add `AuthenticationMiddleware`, and `@requires()` decorator. ## 0.9.5 * Support either `str` or `Secret` for `SessionMiddleware(secret_key=...)`. ## 0.9.4 * Add `config.environ`. * Add `datastructures.Secret`. * Add `datastructures.DatabaseURL`. ## 0.9.3 * Add `config.Config(".env")` ## 0.9.2 * Add optional database support. * Add `request` to GraphQL context. * Hide any password component in `URL.__repr__`. ## 0.9.1 * Handle startup/shutdown errors properly. ## 0.9.0 * `TestClient` can now be used as a context manager, instead of `LifespanContext`. * Lifespan is now handled as middleware. Startup and Shutdown events are visible throughout the middleware stack. ## 0.8.8 * Better support for third-party API schema generators. ## 0.8.7 * Support chunked requests with TestClient. * Cleanup asyncio tasks properly with WSGIMiddleware. * Support using TestClient within endpoints, for service mocking. ## 0.8.6 * Session cookies are now set on the root path. ## 0.8.5 * Support URL convertors. * Support HTTP 304 cache responses from `StaticFiles`. * Resolve character escaping issue with form data. ## 0.8.4 * Default to empty body on responses. ## 0.8.3 * Add 'name' argument to `@app.route()`. * Use 'Host' header for URL reconstruction. ## 0.8.2 ### StaticFiles * StaticFiles no longer reads the file for responses to `HEAD` requests. ## 0.8.1 ### Templating * Add a default templating configuration with Jinja2. Allows the following: ```python app = Starlette(template_directory="templates") @app.route('/') async def homepage(request): # `url_for` is available inside the template. template = app.get_template('index.html') content = template.render(request=request) return HTMLResponse(content) ``` ## 0.8.0 ### Exceptions * Add support for `@app.exception_handler(404)`. * Ensure handled exceptions are not seen as errors by the middleware stack. ### SessionMiddleware * Add `max_age`, and use timestamp-signed cookies. Defaults to two weeks. 
### Cookies * Ensure cookies are strictly HTTP correct. ### StaticFiles * Check directory exists on instantiation. ## 0.7.4 ### Concurrency * Add `starlette.concurrency.run_in_threadpool`. Now handles `contextvar` support. ## 0.7.3 ### Routing * Add `name=` support to `app.mount()`. This allows eg: `app.mount('/static', StaticFiles(directory='static'), name='static')`. ## 0.7.2 ### Middleware * Add support for `@app.middleware("http")` decorator. ### Routing * Add "endpoint" to ASGI scope. ## 0.7.1 ### Debug tracebacks * Improve debug traceback information & styling. ### URL routing * Support mounted URL lookups with "path=", eg. `url_for('static', path=...)`. * Support nested URL lookups, eg. `url_for('admin:user', username=...)`. * Add redirect slashes support. * Add www redirect support. ### Background tasks * Add background task support to `FileResponse` and `StreamingResponse`. ## 0.7.0 ### API Schema support * Add `app.schema_generator = SchemaGenerator(...)`. * Add `app.schema` property. * Add `OpenAPIResponse(...)`. ### GraphQL routing * Drop `app.add_graphql_route("/", ...)` in favor of more consistent `app.add_route("/", GraphQLApp(...))`. ## 0.6.3 ### Routing API * Support routing to methods. * Ensure `url_path_for` works with Mount('/{some_path_params}'). * Fix Router(default=) argument. * Support repeated paths, like: `@app.route("/", methods=["GET"])`, `@app.route("/", methods=["POST"])` * Use the default ThreadPoolExecutor for all sync endpoints. ## 0.6.2 ### SessionMiddleware Added support for `request.session`, with `SessionMiddleware`. ## 0.6.1 ### BaseHTTPMiddleware Added support for `BaseHTTPMiddleware`, which provides a standard request/response interface over a regular ASGI middleware. This means you can write ASGI middleware while still working at a request/response level, rather than handling ASGI messages directly. 
```python from starlette.applications import Starlette from starlette.middleware.base import BaseHTTPMiddleware class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): response = await call_next(request) response.headers['Custom-Header'] = 'Example' return response app = Starlette() app.add_middleware(CustomMiddleware) ``` ## 0.6.0 ### request.path_params The biggest change in 0.6 is that endpoint signatures are no longer: ```python async def func(request: Request, **kwargs) -> Response ``` Instead we just use: ```python async def func(request: Request) -> Response ``` The path parameters are available on the request as `request.path_params`. This is different from most Python web frameworks, but I think it actually ends up being much more nicely consistent all the way through. ### request.url_for() Request and WebSocketSession now support URL reversing with `request.url_for(name, **path_params)`. This method returns a fully qualified `URL` instance. The URL instance is a string-like object. ### app.url_path_for() Applications now support URL path reversing with `app.url_path_for(name, **path_params)`. This method returns a `URL` instance with the path and scheme set. The URL instance is a string-like object, and will return only the path if coerced to a string. ### app.routes Applications now support a `.routes` attribute, which returns a list of `[Route|WebSocketRoute|Mount]`. ### Route, WebSocketRoute, Mount The low-level components of `Router` now match the `@app.route()`, `@app.websocket_route()`, and `app.mount()` signatures. ================================================ FILE: docs/requests.md ================================================ Starlette includes a `Request` class that gives you a nicer interface onto the incoming request, rather than accessing the ASGI scope and receive channel directly.
### Request Signature: `Request(scope, receive=None)` ```python from starlette.requests import Request from starlette.responses import Response async def app(scope, receive, send): assert scope['type'] == 'http' request = Request(scope, receive) content = '%s %s' % (request.method, request.url.path) response = Response(content, media_type='text/plain') await response(scope, receive, send) ``` Requests present a mapping interface, so you can use them in the same way as a `scope`. For instance: `request['path']` will return the ASGI path. If you don't need to access the request body you can instantiate a request without providing an argument to `receive`. #### Method The request method is accessed as `request.method`. #### URL The request URL is accessed as `request.url`. The property is a string-like object that exposes all the components that can be parsed out of the URL. For example: `request.url.path`, `request.url.port`, `request.url.scheme`. #### Headers Headers are exposed as an immutable, case-insensitive, multi-dict. For example: `request.headers['content-type']` #### Query Parameters Query parameters are exposed as an immutable multi-dict. For example: `request.query_params['search']` #### Path Parameters Router path parameters are exposed as a dictionary interface. For example: `request.path_params['username']` #### Client Address The client's remote address is exposed as a named two-tuple `request.client` (or `None`). The hostname or IP address: `request.client.host` The port number from which the client is connecting: `request.client.port` #### Cookies Cookies are exposed as a regular dictionary interface. For example: `request.cookies.get('mycookie')` Cookies are ignored in case of an invalid cookie. 
(RFC2109) #### Body There are a few different interfaces for returning the body of the request: The request body as bytes: `await request.body()` The request body, parsed as form data or multipart: `async with request.form() as form:` The request body, parsed as JSON: `await request.json()` You can also access the request body as a stream, using the `async for` syntax: ```python from starlette.requests import Request from starlette.responses import Response async def app(scope, receive, send): assert scope['type'] == 'http' request = Request(scope, receive) body = b'' async for chunk in request.stream(): body += chunk response = Response(body, media_type='text/plain') await response(scope, receive, send) ``` If you access `.stream()`, the byte chunks are provided without storing the entire body in memory. Any subsequent calls to `.body()`, `.form()`, or `.json()` will raise an error. In some cases, such as long-polling or streaming responses, you might need to determine if the client has dropped the connection. You can determine this state with `disconnected = await request.is_disconnected()`. #### Request Files Request files are normally sent as multipart form data (`multipart/form-data`). Signature: `request.form(max_files=1000, max_fields=1000, max_part_size=1024*1024)` You can configure the maximum number of files and fields with the parameters `max_files` and `max_fields`, and the maximum size of each part with `max_part_size`: ```python async with request.form(max_files=1000, max_fields=1000, max_part_size=1024*1024): ... ``` !!! info These limits exist for security reasons: allowing an unlimited number of fields or files could lead to a denial of service attack by consuming a lot of CPU and memory parsing too many empty fields. When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable multidict, containing both file uploads and text input.
File upload items are represented as instances of `starlette.datastructures.UploadFile`. `UploadFile` has the following attributes: * `filename`: A `str` with the original file name that was uploaded or `None` if it's not available (e.g. `myimage.jpg`). * `content_type`: A `str` with the content type (MIME type / media type) or `None` if it's not available (e.g. `image/jpeg`). * `file`: A `SpooledTemporaryFile` (a file-like object). This is the actual Python file that you can pass directly to other functions or libraries that expect a "file-like" object. * `headers`: A `Headers` object. Often this will only be the `Content-Type` header, but if additional headers were included in the multipart field they will be included here. Note that these headers have no relationship with the headers in `Request.headers`. * `size`: An `int` with the uploaded file's size in bytes. This value is calculated from the request's contents, making it a better choice for finding the uploaded file's size than the `Content-Length` header. `None` if not set. `UploadFile` has the following `async` methods. They all call the corresponding file methods underneath (using the internal `SpooledTemporaryFile`). * `async write(data)`: Writes `data` (`bytes`) to the file. * `async read(size)`: Reads `size` (`int`) bytes of the file. * `async seek(offset)`: Goes to the byte position `offset` (`int`) in the file. * E.g., `await myfile.seek(0)` would go to the start of the file. * `async close()`: Closes the file. As all these methods are `async` methods, you need to "await" them. For example, you can get the file name and the contents with: ```python async with request.form() as form: filename = form["upload_file"].filename contents = await form["upload_file"].read() ``` !!!
info As specified in [RFC-7578: 4.2](https://www.ietf.org/rfc/rfc7578.txt), a form-data part that contains a file is expected to have `name` and `filename` parameters in its `Content-Disposition` header: `Content-Disposition: form-data; name="user"; filename="somefile"`. Although the `filename` parameter is optional according to RFC-7578, Starlette uses it to determine which data should be treated as a file. If a `filename` parameter is supplied, an `UploadFile` object is created to access the underlying file; otherwise, the form-data part is parsed and made available as a raw string. #### Application The originating Starlette application can be accessed via `request.app`. #### Other state If you want to store additional information on the request you can do so using `request.state`. For example: `request.state.time_started = time.time()` ================================================ FILE: docs/responses.md ================================================ Starlette includes a few response classes that handle sending back the appropriate ASGI messages on the `send` channel. ### Response Signature: `Response(content, status_code=200, headers=None, media_type=None)` * `content` - A string or bytestring. * `status_code` - An integer HTTP status code. * `headers` - A dictionary of strings. * `media_type` - A string giving the media type, e.g. "text/html". Starlette will automatically include a Content-Length header. It will also include a Content-Type header, based on the media_type and appending a charset for text types, unless a charset has already been specified in the `media_type`. Once you've instantiated a response, you can send it by calling it as an ASGI application instance.
```python from starlette.responses import Response async def app(scope, receive, send): assert scope['type'] == 'http' response = Response('Hello, world!', media_type='text/plain') await response(scope, receive, send) ``` #### Set Cookie Starlette provides a `set_cookie` method to allow you to set cookies on the response object. Signature: `Response.set_cookie(key, value, max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False, samesite="lax", partitioned=False)` * `key` - A string that will be the cookie's key. * `value` - A string that will be the cookie's value. * `max_age` - An integer that defines the lifetime of the cookie in seconds. A negative integer or a value of `0` will discard the cookie immediately. `Optional` * `expires` - Either an integer that defines the number of seconds until the cookie expires, or a datetime. `Optional` * `path` - A string that specifies the subset of routes to which the cookie will apply. `Optional` * `domain` - A string that specifies the domain for which the cookie is valid. `Optional` * `secure` - A bool indicating that the cookie will only be sent to the server if request is made using SSL and the HTTPS protocol. `Optional` * `httponly` - A bool indicating that the cookie cannot be accessed via JavaScript through `Document.cookie` property, the `XMLHttpRequest` or `Request` APIs. `Optional` * `samesite` - A string that specifies the samesite strategy for the cookie. Valid values are `'lax'`, `'strict'` and `'none'`. Defaults to `'lax'`. `Optional` * `partitioned` - A bool that indicates to user agents that these cross-site cookies should only be available in the same top-level context that the cookie was first set in. Only available for Python 3.14+, otherwise an error will be raised. `Optional` #### Delete Cookie Conversely, Starlette also provides a `delete_cookie` method to manually expire a set cookie. 
Signature: `Response.delete_cookie(key, path='/', domain=None, secure=False, httponly=False, samesite="lax")` The optional arguments have the same meaning as in `set_cookie`. ### HTMLResponse Takes some text or bytes and returns an HTML response. ```python from starlette.responses import HTMLResponse async def app(scope, receive, send): assert scope['type'] == 'http' response = HTMLResponse('<html><body><h1>Hello, world!</h1></body></html>') await response(scope, receive, send) ``` ### PlainTextResponse Takes some text or bytes and returns a plain text response. ```python from starlette.responses import PlainTextResponse async def app(scope, receive, send): assert scope['type'] == 'http' response = PlainTextResponse('Hello, world!') await response(scope, receive, send) ``` ### JSONResponse Takes some data and returns an `application/json` encoded response. ```python from starlette.responses import JSONResponse async def app(scope, receive, send): assert scope['type'] == 'http' response = JSONResponse({'hello': 'world'}) await response(scope, receive, send) ``` #### Custom JSON serialization If you need fine-grained control over JSON serialization, you can subclass `JSONResponse` and override the `render` method. For example, if you wanted to use a third-party JSON library such as [orjson](https://pypi.org/project/orjson/): ```python from typing import Any import orjson from starlette.responses import JSONResponse class OrjsonResponse(JSONResponse): def render(self, content: Any) -> bytes: return orjson.dumps(content) ``` In general you *probably* want to stick with `JSONResponse` by default unless you are micro-optimising a particular endpoint or need to serialize non-standard object types. ### RedirectResponse Returns an HTTP redirect. Uses a 307 status code by default. ```python from starlette.responses import PlainTextResponse, RedirectResponse async def app(scope, receive, send): assert scope['type'] == 'http' if scope['path'] != '/': response = RedirectResponse(url='/') else: response = PlainTextResponse('Hello, world!') await response(scope, receive, send) ``` ### StreamingResponse Takes an async generator or a normal generator/iterator and streams the response body.
```python from starlette.responses import StreamingResponse import asyncio async def slow_numbers(minimum, maximum): yield '<html><body><ul>' for number in range(minimum, maximum + 1): yield '<li>%d</li>' % number await asyncio.sleep(0.5) yield '</ul></body></html>' async def app(scope, receive, send): assert scope['type'] == 'http' generator = slow_numbers(1, 10) response = StreamingResponse(generator, media_type='text/html') await response(scope, receive, send) ``` Keep in mind that file-like objects (like those created by `open()`) are normal iterators. So, you can return them directly in a `StreamingResponse`. ### FileResponse Asynchronously streams a file as the response. Takes a different set of arguments to instantiate than the other response types: * `path` - The path to the file to stream. * `headers` - Any custom headers to include, as a dictionary. * `media_type` - A string giving the media type. If unset, the filename or path will be used to infer a media type. * `filename` - If set, this will be included in the response `Content-Disposition`. * `content_disposition_type` - A string that will be included in the response `Content-Disposition`. Can be set to "attachment" (default) or "inline". File responses will include appropriate `Content-Length`, `Last-Modified` and `ETag` headers. ```python from starlette.responses import FileResponse async def app(scope, receive, send): assert scope['type'] == 'http' response = FileResponse('statics/favicon.ico') await response(scope, receive, send) ``` File responses also support [HTTP range requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests). The `Accept-Ranges: bytes` header will be included in the response if the file exists. For now, only the `bytes` range unit is supported. If the request includes a `Range` header, and the file exists, the response will be a `206 Partial Content` response with the requested range of bytes. If the range is invalid, the response will be a `416 Range Not Satisfiable` response.
## Third-party responses #### [EventSourceResponse](https://github.com/sysid/sse-starlette) A response class that implements [Server-Sent Events](https://html.spec.whatwg.org/multipage/server-sent-events.html). It enables event streaming from the server to the client without the complexity of WebSockets. ================================================ FILE: docs/routing.md ================================================ ## HTTP Routing Starlette has a simple but capable request routing system. A routing table is defined as a list of routes, and passed in when instantiating the application. ```python from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.routing import Route async def homepage(request): return PlainTextResponse("Homepage") async def about(request): return PlainTextResponse("About") routes = [ Route("/", endpoint=homepage), Route("/about", endpoint=about), ] app = Starlette(routes=routes) ``` The `endpoint` argument can be one of: * A regular function or async function, which accepts a single `request` argument and which should return a response. * A class that implements the ASGI interface, such as Starlette's [HTTPEndpoint](endpoints.md#httpendpoint). ## Path Parameters Paths can use URI templating style to capture path components. ```python Route('/users/{username}', user) ``` By default this will capture characters up to the end of the path or the next `/`. You can use convertors to modify what is captured. The available convertors are: * `str` returns a string, and is the default. * `int` returns a Python integer. * `float` returns a Python float. * `uuid` returns a Python `uuid.UUID` instance. * `path` returns the rest of the path, including any additional `/` characters.
Convertors are used by prefixing them with a colon, like so: ```python Route('/users/{user_id:int}', user) Route('/floating-point/{number:float}', floating_point) Route('/uploaded/{rest_of_path:path}', uploaded) ``` If you need a convertor that is not provided, you can create your own. Below is an example of how to create a `datetime` convertor and register it (the regex must only match strings that `convert` can parse, otherwise a matching URL would fail during conversion): ```python from datetime import datetime from starlette.convertors import Convertor, register_url_convertor class DateTimeConvertor(Convertor): regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}" def convert(self, value: str) -> datetime: return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S") def to_string(self, value: datetime) -> str: return value.strftime("%Y-%m-%dT%H:%M:%S") register_url_convertor("datetime", DateTimeConvertor()) ``` After registering it, you'll be able to use it as: ```python Route('/history/{date:datetime}', history) ``` Path parameters are made available in the request, as the `request.path_params` dictionary. ```python async def user(request): user_id = request.path_params['user_id'] ... ``` ## Handling HTTP methods Routes can also specify which HTTP methods are handled by an endpoint: ```python Route('/users/{user_id:int}', user, methods=["GET", "POST"]) ``` By default, function endpoints will only accept `GET` (and `HEAD`) requests unless `methods` is specified. ## Submounting routes In large applications you might find that you want to break out parts of the routing table, based on a common path prefix. ```python routes = [ Route('/', homepage), Mount('/users', routes=[ Route('/', users, methods=['GET', 'POST']), Route('/{username}', user), ]) ] ``` This style allows you to define different subsets of the routing table in different parts of your project.
```python from myproject import users, auth routes = [ Route('/', homepage), Mount('/users', routes=users.routes), Mount('/auth', routes=auth.routes), ] ``` You can also use mounting to include sub-applications within your Starlette application. For example... ```python # This is a standalone static files server: app = StaticFiles(directory="static") # This is a static files server mounted within a Starlette application, # underneath the "/static" path. routes = [ ... Mount("/static", app=StaticFiles(directory="static"), name="static") ] app = Starlette(routes=routes) ``` ## Reverse URL lookups You'll often want to be able to generate the URL for a particular route, such as in cases where you need to return a redirect response. * Signature: `url_for(name, **path_params) -> URL` ```python routes = [ Route("/", homepage, name="homepage") ] # We can use the following to return a URL... url = request.url_for("homepage") ``` URL lookups can include path parameters... ```python routes = [ Route("/users/{username}", user, name="user_detail") ] # We can use the following to return a URL... url = request.url_for("user_detail", username=...) ``` If a `Mount` includes a `name`, then submounts should use a `{prefix}:{name}` style for reverse URL lookups. ```python routes = [ Mount("/users", name="users", routes=[ Route("/", user, name="user_list"), Route("/{username}", user, name="user_detail") ]) ] # We can use the following to return URLs... url = request.url_for("users:user_list") url = request.url_for("users:user_detail", username=...) ``` Mounted applications may include a `path=...` parameter. ```python routes = [ ... Mount("/static", app=StaticFiles(directory="static"), name="static") ] # We can use the following to return URLs... url = request.url_for("static", path="/css/base.css") ``` For cases where there is no `request` instance, you can make reverse lookups against the application, although these will only return the URL path. 
```python url = app.url_path_for("user_detail", username=...) ``` ## Host-based routing If you want to use different routes for the same path depending on the `Host` header, you can use host-based routing. Note that the port is removed from the `Host` header when matching. For example, `Host(host='example.org:3600', ...)` will match whether the incoming `Host` header contains a different port (`example.org:5600`) or no port at all (`example.org`). You can still specify the port if you need it for use in `url_for`. There are several ways to connect host-based routes to your application: ```python site = Router() # Use eg. `@site.route()` to configure this. api = Router() # Use eg. `@api.route()` to configure this. news = Router() # Use eg. `@news.route()` to configure this. routes = [ Host('api.example.org', api, name="site_api") ] app = Starlette(routes=routes) app.host('www.example.org', site, name="main_site") news_host = Host('news.example.org', news) app.router.routes.append(news_host) ``` URL lookups can include host parameters just like path parameters: ```python routes = [ Host("{subdomain}.example.org", name="sub", app=Router(routes=[ Mount("/users", name="users", routes=[ Route("/", user, name="user_list"), Route("/{username}", user, name="user_detail") ]) ])) ] ... url = request.url_for("sub:users:user_detail", username=..., subdomain=...) url = request.url_for("sub:users:user_list", subdomain=...) ``` ## Route priority Incoming paths are matched against each `Route` in order. In cases where more than one route could match an incoming path, you should take care to ensure that more specific routes are listed before general cases. For example: ```python # Don't do this: `/users/me` will never match incoming requests. routes = [ Route('/users/{username}', user), Route('/users/me', current_user), ] # Do this: `/users/me` is tested first.
routes = [ Route('/users/me', current_user), Route('/users/{username}', user), ] ``` ## Working with Router instances If you're working at a low level you might want to use a plain `Router` instance, rather than creating a `Starlette` application. This gives you a lightweight ASGI application that just provides the application routing, without wrapping it up in any middleware. ```python app = Router(routes=[ Route('/', homepage), Mount('/users', routes=[ Route('/', users, methods=['GET', 'POST']), Route('/{username}', user), ]) ]) ``` ## WebSocket Routing When working with WebSocket endpoints, you should use `WebSocketRoute` instead of the usual `Route`. Path parameters and reverse URL lookups for `WebSocketRoute` work the same as for HTTP `Route`, which is described in the HTTP [Route](#http-routing) section above. ```python from starlette.applications import Starlette from starlette.routing import WebSocketRoute async def websocket_index(websocket): await websocket.accept() await websocket.send_text("Hello, websocket!") await websocket.close() async def websocket_user(websocket): name = websocket.path_params["name"] await websocket.accept() await websocket.send_text(f"Hello, {name}") await websocket.close() routes = [ WebSocketRoute("/", endpoint=websocket_index), WebSocketRoute("/{name}", endpoint=websocket_user), ] app = Starlette(routes=routes) ``` The `endpoint` argument can be one of: * An async function, which accepts a single `websocket` argument. * A class that implements the ASGI interface, such as Starlette's [WebSocketEndpoint](endpoints.md#websocketendpoint). ================================================ FILE: docs/schemas.md ================================================ Starlette supports generating API schemas, such as the widely used [OpenAPI specification][openapi]. (Formerly known as "Swagger".)
Schema generation works by inspecting the routes on the application through `app.routes`, and using the docstrings or other attributes on the endpoints in order to determine a complete API schema. Starlette is not tied to any particular schema generation or validation tooling, but includes a simple implementation that generates OpenAPI schemas based on the docstrings. ```python from starlette.applications import Starlette from starlette.routing import Route from starlette.schemas import SchemaGenerator schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} ) def list_users(request): """ responses: 200: description: A list of users. examples: [{"username": "tom"}, {"username": "lucy"}] """ raise NotImplementedError() def create_user(request): """ responses: 200: description: A user. examples: {"username": "tom"} """ raise NotImplementedError() def openapi_schema(request): return schemas.OpenAPIResponse(request=request) routes = [ Route("/users", endpoint=list_users, methods=["GET"]), Route("/users", endpoint=create_user, methods=["POST"]), Route("/schema", endpoint=openapi_schema, include_in_schema=False) ] app = Starlette(routes=routes) ``` We can now access an OpenAPI schema at the "/schema" endpoint. You can generate the API Schema directly with `.get_schema(routes)`: ```python schema = schemas.get_schema(routes=app.routes) assert schema == { "openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}, "paths": { "/users": { "get": { "responses": { 200: { "description": "A list of users.", "examples": [{"username": "tom"}, {"username": "lucy"}], } } }, "post": { "responses": { 200: {"description": "A user.", "examples": {"username": "tom"}} } }, }, }, } ``` You might also want to be able to print out the API schema, so that you can use tooling such as generating API documentation. 
```python if __name__ == '__main__': assert sys.argv[-1] in ("run", "schema"), "Usage: example.py [run|schema]" if sys.argv[-1] == "run": uvicorn.run("example:app", host='0.0.0.0', port=8000) elif sys.argv[-1] == "schema": schema = schemas.get_schema(routes=app.routes) print(yaml.dump(schema, default_flow_style=False)) ``` ### Third party packages #### [starlette-apispec][starlette-apispec] Easy APISpec integration for Starlette, which supports some object serialization libraries. [openapi]: https://github.com/OAI/OpenAPI-Specification [starlette-apispec]: https://github.com/Woile/starlette-apispec ================================================ FILE: docs/server-push.md ================================================ Starlette includes support for HTTP/2 and HTTP/3 server push, making it possible to push resources to the client to speed up page load times. ### `Request.send_push_promise` Used to initiate a server push for a resource. If server push is not available this method does nothing. Signature: `send_push_promise(path)` * `path` - A string denoting the path of the resource. ```python from starlette.applications import Starlette from starlette.responses import HTMLResponse from starlette.routing import Route, Mount from starlette.staticfiles import StaticFiles async def homepage(request): """ Homepage which uses server push to deliver the stylesheet. 
""" await request.send_push_promise("/static/style.css") return HTMLResponse( '' ) routes = [ Route("/", endpoint=homepage), Mount("/static", StaticFiles(directory="static"), name="static") ] app = Starlette(routes=routes) ``` ================================================ FILE: docs/staticfiles.md ================================================ Starlette also includes a `StaticFiles` class for serving files in a given directory: ### StaticFiles Signature: `StaticFiles(directory=None, packages=None, html=False, check_dir=True, follow_symlink=False)` * `directory` - A string or [os.PathLike][pathlike] denoting a directory path. * `packages` - A list of strings or list of tuples of strings of python packages. * `html` - Run in HTML mode. Automatically loads `index.html` for directories if such file exist. * `check_dir` - Ensure that the directory exists upon instantiation. Defaults to `True`. * `follow_symlink` - A boolean indicating if symbolic links for files and directories should be followed. Defaults to `False`. You can combine this ASGI application with Starlette's routing to provide comprehensive static file serving. ```python from starlette.applications import Starlette from starlette.routing import Mount from starlette.staticfiles import StaticFiles routes = [ ... Mount('/static', app=StaticFiles(directory='static'), name="static"), ] app = Starlette(routes=routes) ``` Static files will respond with "404 Not found" or "405 Method not allowed" responses for requests which do not match. In HTML mode if `404.html` file exists it will be shown as 404 response. The `packages` option can be used to include "static" directories contained within a python package. The Python "bootstrap4" package is an example of this. ```python from starlette.applications import Starlette from starlette.routing import Mount from starlette.staticfiles import StaticFiles routes=[ ... 
Mount('/static', app=StaticFiles(directory='static', packages=['bootstrap4']), name="static"), ] app = Starlette(routes=routes) ``` By default `StaticFiles` will look for `statics` directory in each package, you can change the default directory by specifying a tuple of strings. ```python routes=[ ... Mount('/static', app=StaticFiles(packages=[('bootstrap4', 'static')]), name="static"), ] ``` You may prefer to include static files directly inside the "static" directory rather than using Python packaging to include static files, but it can be useful for bundling up reusable components. [pathlike]: https://docs.python.org/3/library/os.html#os.PathLike ================================================ FILE: docs/templates.md ================================================ Starlette is not _strictly_ coupled to any particular templating engine, but Jinja2 provides an excellent choice. ??? abstract "API Reference" ::: starlette.templating.Jinja2Templates options: parameter_headings: false show_root_heading: true heading_level: 3 filters: - "__init__" Starlette provides a simple way to get `jinja2` configured. This is probably what you want to use by default. ```python from starlette.applications import Starlette from starlette.routing import Route, Mount from starlette.templating import Jinja2Templates from starlette.staticfiles import StaticFiles templates = Jinja2Templates(directory='templates') async def homepage(request): return templates.TemplateResponse(request, 'index.html') routes = [ Route('/', endpoint=homepage), Mount('/static', StaticFiles(directory='static'), name='static') ] app = Starlette(debug=True, routes=routes) ``` Note that the incoming `request` instance must be included as part of the template context. The Jinja2 template context will automatically include a `url_for` function, so we can correctly hyperlink to other pages within the application. 
For example, we can link to static files from within our HTML templates: ```html <link href="{{ url_for('static', path='/css/bootstrap.min.css') }}" rel="stylesheet" /> ``` If you want to use [custom filters][jinja2], you will need to update the `env` property of `Jinja2Templates`: ```python from commonmark import commonmark from starlette.templating import Jinja2Templates def marked_filter(text): return commonmark(text) templates = Jinja2Templates(directory='templates') templates.env.filters['marked'] = marked_filter ``` ## Using custom jinja2.Environment instance Starlette also accepts a preconfigured [`jinja2.Environment`](https://jinja.palletsprojects.com/en/3.0.x/api/#api) instance. ```python import jinja2 from starlette.templating import Jinja2Templates env = jinja2.Environment(...) templates = Jinja2Templates(env=env) ``` ## Autoescape When using the `directory` argument, Starlette enables autoescape by default for `.html`, `.htm`, and `.xml` templates using [`jinja2.select_autoescape()`](https://jinja.palletsprojects.com/en/stable/api/#jinja2.select_autoescape). This protects against Cross-Site Scripting (XSS) vulnerabilities by escaping user-provided content before rendering it in the template. For example, if a user submits `<script>alert('XSS')</script>` as their name, it will be rendered as `&lt;script&gt;alert(&#39;XSS&#39;)&lt;/script&gt;` instead of being executed as JavaScript. ## Context processors A context processor is a function that returns a dictionary to be merged into a template context. Every function takes only one argument `request` and must return a dictionary to add to the context. A common use case of context processors is to extend the template context with shared variables. ```python import typing from starlette.requests import Request def app_context(request: Request) -> typing.Dict[str, typing.Any]: return {'app': request.app} ``` ### Registering context processors Pass context processors to the `context_processors` argument of the `Jinja2Templates` class.
```python import typing from starlette.requests import Request from starlette.templating import Jinja2Templates def app_context(request: Request) -> typing.Dict[str, typing.Any]: return {'app': request.app} templates = Jinja2Templates( directory='templates', context_processors=[app_context] ) ``` !!! info Asynchronous functions as context processors are not supported. ## Testing template responses When using the test client, template responses include `.template` and `.context` attributes. ```python from starlette.testclient import TestClient def test_homepage(): client = TestClient(app) response = client.get("/") assert response.status_code == 200 assert response.template.name == 'index.html' assert "request" in response.context ``` ## Asynchronous template rendering Jinja2 supports async template rendering, however as a general rule we'd recommend that you keep your templates free from logic that invokes database lookups, or other I/O operations. Instead we'd recommend that you ensure that your endpoints perform all I/O, for example, strictly evaluate any database queries within the view and include the final results in the context. [jinja2]: https://jinja.palletsprojects.com/en/3.0.x/api/?highlight=environment#writing-filters [pathlike]: https://docs.python.org/3/library/os.html#os.PathLike ================================================ FILE: docs/testclient.md ================================================ ??? abstract "API Reference" ::: starlette.testclient.TestClient options: parameter_headings: false show_bases: false show_root_heading: true heading_level: 3 filters: - "__init__" The test client allows you to make requests against your ASGI application, using the `httpx` library. 
```python from starlette.responses import HTMLResponse from starlette.testclient import TestClient async def app(scope, receive, send): assert scope['type'] == 'http' response = HTMLResponse('Hello, world!') await response(scope, receive, send) def test_app(): client = TestClient(app) response = client.get('/') assert response.status_code == 200 ``` The test client exposes the same interface as any other `httpx` session. In particular, note that the calls to make a request are just standard function calls, not awaitables. You can use any of `httpx` standard API, such as authentication, session cookies handling, or file uploads. For example, to set headers on the TestClient you can do: ```python client = TestClient(app) # Set headers on the client for future requests client.headers = {"Authorization": "..."} response = client.get("/") # Set headers for each request separately response = client.get("/", headers={"Authorization": "..."}) ``` And for example to send files with the TestClient: ```python client = TestClient(app) # Send a single file with open("example.txt", "rb") as f: response = client.post("/form", files={"file": f}) # Send multiple files with open("example.txt", "rb") as f1: with open("example.png", "rb") as f2: files = {"file1": f1, "file2": ("filename", f2, "image/png")} response = client.post("/form", files=files) ``` For more information you can check the `httpx` [documentation](https://www.python-httpx.org/advanced/). By default the `TestClient` will raise any exceptions that occur in the application. Occasionally you might want to test the content of 500 error responses, rather than allowing client to raise the server exception. In this case you should use `client = TestClient(app, raise_server_exceptions=False)`. !!! note If you want the `TestClient` to run the `lifespan` handler, you will need to use the `TestClient` as a context manager. It will not be triggered when the `TestClient` is instantiated. 
You can learn more about it [here](lifespan.md#running-lifespan-in-tests). ### Change client address By default, the TestClient will set the client host to `"testserver"` and the port to `50000`. You can change the client address by passing the `client` argument to the `TestClient` constructor: ```python client = TestClient(app, client=('localhost', 8000)) ``` ### Selecting the Async backend `TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary). These options are passed to `anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options) for more information about the accepted backend options. By default, `asyncio` is used with default options. To run `Trio`, pass `backend="trio"`. For example: ```python def test_app(): with TestClient(app, backend="trio") as client: ... ``` To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example: ```python def test_app(): with TestClient(app, backend_options={"use_uvloop": True}) as client: ... ``` ### Testing WebSocket sessions You can also test websocket sessions with the test client. The `httpx` library will be used to build the initial handshake, meaning you can use the same authentication options and other headers between both http and websocket testing. ```python from starlette.testclient import TestClient from starlette.websockets import WebSocket async def app(scope, receive, send): assert scope['type'] == 'websocket' websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.send_text('Hello, world!') await websocket.close() def test_app(): client = TestClient(app) with client.websocket_connect('/') as websocket: data = websocket.receive_text() assert data == 'Hello, world!' ``` The operations on the session are standard function calls, not awaitables. It's important to use the session within a context-managed `with` block.
This ensures that the background thread on which the ASGI application runs is properly terminated, and that any exceptions that occur within the application are always raised by the test client. #### Establishing a test session * `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `httpx.get()`. May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection. `websocket_connect()` must be used as a context manager (in a `with` block). !!! note The `params` argument is not supported by `websocket_connect`. If you need to pass query arguments, hard-code them directly in the URL. ```python with client.websocket_connect('/path?foo=bar') as websocket: ... ``` #### Sending data * `.send_text(data)` - Send the given text to the application. * `.send_bytes(data)` - Send the given bytes to the application. * `.send_json(data, mode="text")` - Send the given data to the application. Use `mode="binary"` to send JSON over binary data frames. #### Receiving data * `.receive_text()` - Wait for incoming text sent by the application and return it. * `.receive_bytes()` - Wait for incoming bytestring sent by the application and return it. * `.receive_json(mode="text")` - Wait for incoming json data sent by the application and return it. Use `mode="binary"` to receive JSON over binary data frames. May raise `starlette.websockets.WebSocketDisconnect`. #### Closing the connection * `.close(code=1000)` - Perform a client-side close of the websocket connection. ### Asynchronous tests Sometimes you will want to do async things outside of your application. For example, you might want to check the state of your database after calling your app using your existing async database client/infrastructure. For these situations, using `TestClient` is difficult because it creates its own event loop and async resources (like a database connection) often cannot be shared across event loops.
The simplest way to work around this is to just make your entire test async and use an async client, like [httpx.AsyncClient]. Here is an example of such a test: ```python from httpx import AsyncClient, ASGITransport from starlette.applications import Starlette from starlette.routing import Route from starlette.requests import Request from starlette.responses import PlainTextResponse def hello(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello World!") app = Starlette(routes=[Route("/", hello)]) # if you're using pytest, you'll need to add an async marker like: # @pytest.mark.anyio # using https://github.com/agronholm/anyio # or install and configure pytest-asyncio (https://github.com/pytest-dev/pytest-asyncio) async def test_app() -> None: # note: you _must_ set `base_url` for relative urls like "/" to work transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as client: r = await client.get("/") assert r.status_code == 200 assert r.text == "Hello World!" ``` [httpx.AsyncClient]: https://www.python-httpx.org/advanced/#calling-into-python-web-apps ================================================ FILE: docs/third-party-packages.md ================================================ Starlette has a rapidly growing community of developers, building tools that integrate into Starlette, tools that depend on Starlette, etc. Here are some of those third-party packages: ## Plugins ### Apitally GitHub | Documentation Analytics, request logging and monitoring for REST APIs built with Starlette (and other frameworks). ### Authlib GitHub | Documentation The ultimate Python library for building OAuth and OpenID Connect clients and servers. Check out how to integrate with [Starlette](https://docs.authlib.org/en/latest/client/starlette.html). ### ChannelBox GitHub Repository ChannelBox is a lightweight solution for WebSocket broadcasting in ASGI applications.
It allows sending messages to named WebSocket channel groups and integrates with Starlette and FastAPI. ### Imia GitHub An authentication framework for Starlette with pluggable authenticators and login/logout flow. ### Mangum GitHub Serverless ASGI adapter for AWS Lambda & API Gateway. ### Nejma GitHub Manage and send messages to groups of channels using websockets. Check out nejma-chat, a simple chat application built using `nejma` and `starlette`. ### Scout APM GitHub An APM (Application Performance Monitoring) solution that can instrument your application to find performance bottlenecks. ### SpecTree GitHub Generate OpenAPI spec document and validate request & response with Python annotations. Less boilerplate code (no need for YAML). ### Starlette APISpec GitHub Simple APISpec integration for Starlette. Document your REST API built with Starlette by declaring OpenAPI (Swagger) schemas in YAML format in your endpoint's docstrings. ### Starlette Compress GitHub Starlette-Compress is a fast and simple middleware for compressing responses in Starlette. It adds ZStd, Brotli, and GZip compression support with sensible default configuration. ### Starlette Context GitHub Middleware for Starlette that allows you to store and access the context data of a request. Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id. ### Starlette Cramjam GitHub A Starlette middleware that supports the **brotli**, **gzip** and **deflate** compression algorithms with minimal requirements. ### Starlette OAuth2 API GitLab A starlette middleware to add authentication and authorization through JWTs. It relies solely on an auth provider to issue access and/or id tokens to clients. ### Starlette Prometheus GitHub A plugin for providing an endpoint that exposes [Prometheus](https://prometheus.io/) metrics based on its [official python client](https://github.com/prometheus/client_python).
### Starlette WTF GitHub A simple tool for integrating Starlette and WTForms. It is modeled on the excellent Flask-WTF library. ### Starlette-Login GitHub | Documentation User session management for Starlette. It handles the common tasks of logging in, logging out, and remembering your users' sessions over extended periods of time. ### Starsessions GitHub An alternate session support implementation with customizable storage backends. ### webargs-starlette GitHub Declarative request parsing and validation for Starlette, built on top of [webargs](https://github.com/marshmallow-code/webargs). Allows you to parse querystring, JSON, form, headers, and cookies using type annotations. ### DecoRouter GitHub FastAPI style routing for Starlette. Allows you to use decorators to generate routing tables. ### Starception GitHub Beautiful exception page for Starlette apps. ### Starlette-Admin GitHub | Documentation Simple and extensible admin interface framework. Built with [Tabler](https://tabler.io/) and [Datatables](https://datatables.net/), it allows you to quickly generate a fully customizable admin interface for your models. You can export your data to many formats (*CSV*, *PDF*, *Excel*, etc), filter your data with complex queries including `AND` and `OR` conditions, upload files, ... ### Vellox GitHub Serverless ASGI adapter for GCP Cloud Functions. ### Starlette Bridge GitHub | Documentation With the deprecation of `on_startup` and `on_shutdown`, Starlette Bridge makes sure you can still use the old ways of declaring events with a particularity that internally, in fact, creates the `lifespan` for you. This way backwards compatibility is assured for the existing packages out there while maintaining the integrity of the new `lifespan` events of `Starlette`. ## Frameworks ### FastAPI GitHub | Documentation High performance, easy to learn, fast to code, ready for production web API framework.
Inspired by **APIStar**'s previous server system with type declarations for route parameters, based on the OpenAPI specification version 3.0.0+ (with JSON Schema), powered by **Pydantic** for the data handling. ### Flama GitHub | Documentation Flama is a **data-science oriented framework** to rapidly build modern and robust **machine learning** (ML) APIs. The main aim of the framework is to make the deployment of ML APIs ridiculously simple. With Flama, data scientists can now quickly turn their ML models into asynchronous, auto-documented APIs with just a single line of code. All in just a few seconds! Flama comes with an intuitive CLI, and provides an easy-to-learn philosophy to speed up the building of **highly performant** GraphQL, REST, and ML APIs. Besides, it comprises an ideal solution for the development of asynchronous and **production-ready** services, offering **automatic deployment** for ML models. ### Greppo GitHub | Documentation A Python framework for building geospatial dashboards and web-applications. Greppo is an open-source Python framework that makes it easy to build geospatial dashboards and web-applications. It provides a toolkit to quickly integrate data, algorithms, visualizations and UI for interactivity. It provides APIs to update the variables in the backend, recompute the logic, and reflect the changes in the frontend (data mutation hook). ### Responder GitHub | Documentation Async web service framework. Some Features: flask-style route expression, yaml support, OpenAPI schema generation, background tasks, graphql. ### Starlette-apps Roll your own framework with a simple app system, like [Django-GDAPS](https://gdaps.readthedocs.io/en/latest/) or [CakePHP](https://cakephp.org/). GitHub ### Dark Star A simple framework to help minimise the code needed to get HTML to the browser. Changes your file paths into Starlette routes and puts your view code right next to your template.
Includes support for [htmx](https://htmx.org) to help enhance your frontend. Docs GitHub ### Xpresso A flexible and extendable web framework built on top of Starlette, Pydantic and [di](https://github.com/adriangb/di). GitHub | Documentation ### Ellar GitHub | Documentation Ellar is an ASGI web framework for building fast, efficient and scalable RESTAPIs and server-side applications. It offers a high level of abstraction in building server-side applications and combines elements of OOP (Object Oriented Programming), and FP (Functional Programming) - Inspired by Nestjs. It is built on 3 core libraries **Starlette**, **Pydantic**, and **injector**. ### Apiman An extension to integrate Swagger/OpenAPI document easily for Starlette project and provide [SwaggerUI](http://swagger.io/swagger-ui/) and [RedocUI](https://rebilly.github.io/ReDoc/). GitHub ### Starlette-Babel Provides translations, localization, and timezone support via Babel integration. GitHub ### Starlette-StaticResources GitHub Allows mounting [package resources](https://docs.python.org/3/library/importlib.resources.html#module-importlib.resources) for static data, similar to [StaticFiles](staticfiles.md). ### Sentry GitHub | Documentation Sentry is a software error detection tool. It offers actionable insights for resolving performance issues and errors, allowing users to diagnose, fix, and optimize Python debugging. Additionally, it integrates seamlessly with Starlette for Python application development. Sentry's capabilities include error tracking, performance insights, contextual information, and alerts/notifications. ### Shiny GitHub | Documentation Leveraging Starlette and asyncio, Shiny allows developers to create effortless Python web applications using the power of reactive programming. Shiny eliminates the hassle of manual state management, automatically determining the best execution path for your app at runtime while simultaneously minimizing re-rendering. 
This means that Shiny can support everything from the simplest dashboard to full-featured web apps. ================================================ FILE: docs/threadpool.md ================================================ # Thread Pool Starlette uses a thread pool in several scenarios to avoid blocking the event loop: - When you create a synchronous endpoint using `def` instead of `async def` - When serving files with [`FileResponse`](responses.md#fileresponse) - When handling file uploads with [`UploadFile`](requests.md#request-files) - When running synchronous background tasks with [`BackgroundTask`](background.md) - And some other scenarios that may not be documented... Starlette will run your code in a thread pool to avoid blocking the event loop. This applies for endpoint functions and background tasks you create, but also for internal Starlette code. To be more precise, Starlette uses `anyio.to_thread.run_sync` to run the synchronous code. ## Concurrency Limitations The default thread pool size is only 40 _tokens_. This means that only 40 threads can run at the same time. This limit is shared with other libraries: for example FastAPI also uses `anyio` to run sync dependencies, which also uses up thread capacity. If you need to run more threads, you can increase the number of _tokens_: ```py import anyio.to_thread limiter = anyio.to_thread.current_default_thread_limiter() limiter.total_tokens = 100 ``` The above code will increase the number of _tokens_ to 100. Increasing the number of threads may have a performance and memory impact, so be careful when doing so. ================================================ FILE: docs/websockets.md ================================================ Starlette includes a `WebSocket` class that fulfils a similar role to the HTTP request, but that allows sending and receiving data on a websocket. 
### WebSocket Signature: `WebSocket(scope, receive=None, send=None)` ```python from starlette.websockets import WebSocket async def app(scope, receive, send): websocket = WebSocket(scope=scope, receive=receive, send=send) await websocket.accept() await websocket.send_text('Hello, world!') await websocket.close() ``` WebSockets present a mapping interface, so you can use them in the same way as a `scope`. For instance: `websocket['path']` will return the ASGI path. #### URL The websocket URL is accessed as `websocket.url`. The property is a string-like object that exposes all the components that can be parsed out of the URL. For example: `websocket.url.path`, `websocket.url.port`, `websocket.url.scheme`. #### Headers Headers are exposed as an immutable, case-insensitive, multi-dict. For example: `websocket.headers['sec-websocket-version']` #### Query Parameters Query parameters are exposed as an immutable multi-dict. For example: `websocket.query_params['search']` #### Path Parameters Router path parameters are exposed as a dictionary interface. For example: `websocket.path_params['username']` ### Accepting the connection * `await websocket.accept(subprotocol=None, headers=None)` ### Sending data * `await websocket.send_text(data)` * `await websocket.send_bytes(data)` * `await websocket.send_json(data)` JSON messages default to being sent over text data frames, from version 0.10.0 onwards. Use `websocket.send_json(data, mode="binary")` to send JSON over binary data frames. ### Receiving data * `await websocket.receive_text()` * `await websocket.receive_bytes()` * `await websocket.receive_json()` May raise `starlette.websockets.WebSocketDisconnect()`. JSON messages default to being received over text data frames, from version 0.10.0 onwards. Use `websocket.receive_json(mode="binary")` to receive JSON over binary data frames.
### Iterating data * `websocket.iter_text()` * `websocket.iter_bytes()` * `websocket.iter_json()` Similar to `receive_text`, `receive_bytes`, and `receive_json` but returns an async iterator. ```python hl_lines="7-8" from starlette.websockets import WebSocket async def app(scope, receive, send): websocket = WebSocket(scope=scope, receive=receive, send=send) await websocket.accept() async for message in websocket.iter_text(): await websocket.send_text(f"Message text was: {message}") await websocket.close() ``` When `starlette.websockets.WebSocketDisconnect` is raised, the iterator will exit. ### Closing the connection * `await websocket.close(code=1000, reason=None)` ### Sending and receiving messages If you need to send or receive raw ASGI messages then you should use `websocket.send()` and `websocket.receive()` rather than using the raw `send` and `receive` callables. This will ensure that the websocket's state is kept correctly updated. * `await websocket.send(message)` * `await websocket.receive()` ### Send Denial Response If you call `websocket.close()` before calling `websocket.accept()` then the server will automatically send a HTTP 403 error to the client. If you want to send a different error response, you can use the `websocket.send_denial_response()` method. This will send the response and then close the connection. * `await websocket.send_denial_response(response)` This requires the ASGI server to support the WebSocket Denial Response extension. If it is not supported a `RuntimeError` will be raised. In the context of `Starlette`, you can also use the `HTTPException` to achieve the same result. 
```python from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket def is_authorized(subprotocols: list[str]): if len(subprotocols) != 2: return False if subprotocols[0] != "Authorization": return False # Here we are hard coding the token, in a real application you would validate the token # against a database or an external service. if subprotocols[1] != "token": return False return True async def websocket_endpoint(websocket: WebSocket): subprotocols = websocket.scope["subprotocols"] if not is_authorized(subprotocols): raise HTTPException(status_code=401, detail="Unauthorized") await websocket.accept("Authorization") await websocket.send_text("Hello, world!") await websocket.close() app = Starlette(debug=True, routes=[WebSocketRoute("/ws", websocket_endpoint)]) ``` ================================================ FILE: mkdocs.yml ================================================ site_name: Starlette site_description: The little ASGI library that shines. 
site_url: https://starlette.dev repo_name: Kludex/starlette repo_url: https://github.com/Kludex/starlette edit_uri: edit/main/docs/ strict: true theme: name: "material" custom_dir: docs/overrides palette: - scheme: "default" media: "(prefers-color-scheme: light)" toggle: icon: "material/lightbulb" name: "Switch to dark mode" - scheme: "slate" media: "(prefers-color-scheme: dark)" primary: "blue" toggle: icon: "material/lightbulb-outline" name: "Switch to light mode" icon: repo: fontawesome/brands/github features: - content.code.copy - toc.follow nav: - Introduction: "index.md" - Features: - Applications: "applications.md" - Requests: "requests.md" - Responses: "responses.md" - WebSockets: "websockets.md" - Routing: "routing.md" - Endpoints: "endpoints.md" - Middleware: "middleware.md" - Static Files: "staticfiles.md" - Templates: "templates.md" - Database: "database.md" - GraphQL: "graphql.md" - Authentication: "authentication.md" - API Schemas: "schemas.md" - Lifespan: "lifespan.md" - Background Tasks: "background.md" - Server Push: "server-push.md" - Exceptions: "exceptions.md" - Configuration: "config.md" - Thread Pool: "threadpool.md" - Test Client: "testclient.md" - Release Notes: "release-notes.md" - Community: - Third Party Packages: "third-party-packages.md" - Contributing: "contributing.md" extra_css: - css/custom.css extra_javascript: - js/custom.js extra: analytics: provider: google property: G-Z37GTYBR6M social: - icon: fontawesome/brands/github-alt link: https://github.com/Kludex/starlette - icon: fontawesome/brands/discord link: https://discord.com/invite/RxKUF5JuHs - icon: fontawesome/brands/twitter link: https://x.com/marcelotryle - icon: fontawesome/brands/linkedin link: https://www.linkedin.com/in/marcelotryle - icon: fontawesome/solid/globe link: https://fastapiexpert.com markdown_extensions: - attr_list - admonition - pymdownx.highlight - pymdownx.superfences - pymdownx.details - pymdownx.tabbed: alternate_style: true - pymdownx.emoji: 
emoji_index: !!python/name:material.extensions.emoji.twemoji emoji_generator: !!python/name:material.extensions.emoji.to_svg - pymdownx.tasklist: custom_checkbox: true watch: - starlette plugins: - search - mkdocstrings: handlers: python: options: docstring_section_style: list show_root_toc_entry: false members_order: source separate_signature: true filters: ["!^_"] docstring_options: ignore_init_summary: true merge_init_into_class: true parameter_headings: true show_signature_annotations: true show_source: false signature_crossrefs: true inventories: - url: https://docs.python.org/3/objects.inv ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [project] name = "starlette" dynamic = ["version"] description = "The little ASGI library that shines." readme = "README.md" license = "BSD-3-Clause" license-files = ["LICENSE.md"] requires-python = ">=3.10" authors = [ { name = "Tom Christie", email = "tom@tomchristie.com" } ] maintainers = [ { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, ] classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Web Environment", "Framework :: AnyIO", "Intended Audience :: Developers", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Topic :: Internet :: WWW/HTTP", ] dependencies = [ "anyio>=3.6.2,<5", "typing_extensions>=4.10.0; python_version < '3.13'", ] [project.optional-dependencies] full = [ "itsdangerous", "jinja2", "python-multipart>=0.0.18", "pyyaml", "httpx>=0.27.0,<0.29.0", ] [dependency-groups] dev = [ # We add starlette[full] so `uv sync` considers the extras. 
"starlette[full]", "coverage>=7.8.2", "importlib-metadata==8.7.1", "mypy==1.16.1", "ruff==0.15.4", "types-PyYAML==6.0.12.20250516", "pytest==9.0.2", "trio==0.33.0", # Check dist "twine==6.2.0", ] docs = [ "black==26.3.1", "mkdocstrings>=1.0.2", "mkdocstrings-python>=2.0.1", "zensical>=0.0.19", ] [tool.uv] default-groups = ["dev", "docs"] required-version = ">=0.8.6" [project.urls] Homepage = "https://github.com/Kludex/starlette" Documentation = "https://starlette.dev/" Changelog = "https://starlette.dev/release-notes/" Funding = "https://github.com/sponsors/Kludex" Source = "https://github.com/Kludex/starlette" [tool.hatch.version] path = "starlette/__init__.py" [tool.ruff] line-length = 120 [tool.ruff.lint] select = [ "E", # https://docs.astral.sh/ruff/rules/#error-e "F", # https://docs.astral.sh/ruff/rules/#pyflakes-f "I", # https://docs.astral.sh/ruff/rules/#isort-i "FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up "RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf ] ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/ [tool.ruff.lint.isort] combine-as-imports = true [tool.mypy] strict = true [[tool.mypy.overrides]] module = "starlette.testclient.*" implicit_optional = true [tool.pytest.ini_options] addopts = "-rXs --strict-config --strict-markers" xfail_strict = true filterwarnings = [ # Turn warnings that aren't filtered into exceptions "error", "ignore: run_until_first_complete is deprecated and will be removed in a future version.:DeprecationWarning", "ignore: starlette.middleware.wsgi is deprecated and will be removed in a future release.*:DeprecationWarning", "ignore: Async generator 'starlette.requests.Request.stream' was garbage collected before it had been exhausted.*:ResourceWarning", "ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning", ] [tool.coverage.run] branch = true source_pkgs = 
["starlette", "tests"] [tool.coverage.report] exclude_also = [ "@overload", "raise NotImplementedError", ] ================================================ FILE: scripts/README.md ================================================ # Development Scripts * `scripts/install` - Install dependencies in a virtual environment. * `scripts/test` - Run the test suite. * `scripts/lint` - Run the automated code linting/formatting tools. * `scripts/check` - Check formatting, type checking, and linting without modifying files, and verify that the changelog and package versions match. * `scripts/coverage` - Check that code coverage is complete. * `scripts/build` - Build source and wheel packages, validate them with `twine check`, and build the documentation. * `scripts/docs` - Serve the documentation locally with `zensical serve`. * `scripts/sync-version` - Check that the latest version in `docs/release-notes.md` matches the version in `starlette/__init__.py`. Styled after GitHub's ["Scripts to Rule Them All"](https://github.com/github/scripts-to-rule-them-all). ================================================ FILE: scripts/build ================================================ #!/bin/sh -e set -x uv build uv run twine check dist/* uv run zensical build --clean ================================================ FILE: scripts/check ================================================ #!/bin/sh -e export SOURCE_FILES="starlette tests" set -x ./scripts/sync-version uv run ruff format --check --diff $SOURCE_FILES uv run mypy $SOURCE_FILES uv run ruff check $SOURCE_FILES ================================================ FILE: scripts/coverage ================================================ #!/bin/sh -e set -x uv run coverage report --show-missing --skip-covered --fail-under=100 ================================================ FILE: scripts/docs ================================================ #!/bin/sh -e set -x uv run zensical serve ================================================ FILE: scripts/install ================================================ #!/bin/sh -e set -x uv sync --frozen ================================================ FILE: scripts/lint ================================================ #!/bin/sh -e export SOURCE_FILES="starlette tests" set -x uv run ruff format $SOURCE_FILES uv run ruff check --fix $SOURCE_FILES
================================================ FILE: scripts/sync-version ================================================ #!/bin/sh -e SEMVER_REGEX="([0-9]+)\.([0-9]+)\.([0-9]+)(-([0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*))?(\+[0-9A-Za-z-]+)?" CHANGELOG_VERSION=$(grep -o -E "$SEMVER_REGEX" docs/release-notes.md | head -1) VERSION=$(grep -o -E "$SEMVER_REGEX" starlette/__init__.py | head -1) if [ "$CHANGELOG_VERSION" != "$VERSION" ]; then echo "Version in changelog does not match version in starlette/__init__.py!" exit 1 fi ================================================ FILE: scripts/test ================================================ #!/bin/sh set -ex if [ -z "$GITHUB_ACTIONS" ]; then scripts/check fi uv run coverage run -m pytest "$@" if [ -z "$GITHUB_ACTIONS" ]; then scripts/coverage fi ================================================ FILE: starlette/__init__.py ================================================ __version__ = "1.0.0rc1" ================================================ FILE: starlette/_exception_handler.py ================================================ from __future__ import annotations from typing import Any from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send from starlette.websockets import WebSocket ExceptionHandlers = dict[Any, ExceptionHandler] StatusHandlers = dict[int, ExceptionHandler] def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None: for cls in type(exc).__mro__: if cls in exc_handlers: return exc_handlers[cls] return None def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp: exception_handlers: ExceptionHandlers status_handlers: StatusHandlers try: exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"] except KeyError:
exception_handlers, status_handlers = {}, {} async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: response_started = False async def sender(message: Message) -> None: nonlocal response_started if message["type"] == "http.response.start": response_started = True await send(message) try: await app(scope, receive, sender) except Exception as exc: handler = None if isinstance(exc, HTTPException): handler = status_handlers.get(exc.status_code) if handler is None: handler = _lookup_exception_handler(exception_handlers, exc) if handler is None: raise exc if response_started: raise RuntimeError("Caught handled exception, but response already started.") from exc if is_async_callable(handler): response = await handler(conn, exc) else: response = await run_in_threadpool(handler, conn, exc) if response is not None: await response(scope, receive, sender) return wrapped_app ================================================ FILE: starlette/_utils.py ================================================ from __future__ import annotations import functools import sys from collections.abc import Awaitable, Callable, Generator from contextlib import AbstractAsyncContextManager, contextmanager from typing import Any, Generic, Protocol, TypeVar, overload from starlette.types import Scope if sys.version_info >= (3, 13): # pragma: no cover from inspect import iscoroutinefunction from typing import TypeIs else: # pragma: no cover from asyncio import iscoroutinefunction from typing_extensions import TypeIs has_exceptiongroups = True if sys.version_info < (3, 11): # pragma: no cover try: from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found] except ImportError: has_exceptiongroups = False T = TypeVar("T") AwaitableCallable = Callable[..., Awaitable[T]] @overload def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ... @overload def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ... 
def is_async_callable(obj: Any) -> Any: while isinstance(obj, functools.partial): obj = obj.func return iscoroutinefunction(obj) or (callable(obj) and iscoroutinefunction(obj.__call__)) T_co = TypeVar("T_co", covariant=True) class AwaitableOrContextManager( Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co] ): ... # pragma: no branch class SupportsAsyncClose(Protocol): async def close(self) -> None: ... # pragma: no cover SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False) class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]): __slots__ = ("aw", "entered") def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None: self.aw = aw def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]: return self.aw.__await__() async def __aenter__(self) -> SupportsAsyncCloseType: self.entered = await self.aw return self.entered async def __aexit__(self, *args: Any) -> None | bool: await self.entered.close() return None @contextmanager def collapse_excgroups() -> Generator[None, None, None]: try: yield except BaseException as exc: if has_exceptiongroups: # pragma: no cover while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: exc = exc.exceptions[0] raise exc def get_route_path(scope: Scope) -> str: path: str = scope["path"] root_path = scope.get("root_path", "") if not root_path: return path if not path.startswith(root_path): return path if path == root_path: return "" if path[len(root_path)] == "/": return path[len(root_path) :] return path ================================================ FILE: starlette/applications.py ================================================ from __future__ import annotations from collections.abc import Awaitable, Callable, Mapping, Sequence from typing import Any, ParamSpec, TypeVar from starlette.datastructures import State, URLPath from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.errors 
import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send AppType = TypeVar("AppType", bound="Starlette") P = ParamSpec("P") class Starlette: """Creates a Starlette application.""" def __init__( self: AppType, debug: bool = False, routes: Sequence[BaseRoute] | None = None, middleware: Sequence[Middleware] | None = None, exception_handlers: Mapping[Any, ExceptionHandler] | None = None, lifespan: Lifespan[AppType] | None = None, ) -> None: """Initializes the application. Parameters: debug: Boolean indicating if debug tracebacks should be returned on errors. routes: A list of routes to serve incoming HTTP and WebSocket requests. middleware: A list of middleware to run for every request. A starlette application will always automatically include two middleware classes. `ServerErrorMiddleware` is added as the very outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled exception cases occurring in the routing or endpoints. exception_handlers: A mapping of either integer status codes, or exception class types onto callables which handle the exceptions. Exception handler callables should be of the form `handler(request, exc) -> response` and may be either standard functions, or async functions. lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. This is the newer style that replaces the `on_startup` and `on_shutdown` handlers, which this constructor does not accept.
""" self.debug = debug self.state = State() self.router = Router(routes, lifespan=lifespan) self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers) self.user_middleware = [] if middleware is None else list(middleware) self.middleware_stack: ASGIApp | None = None def build_middleware_stack(self) -> ASGIApp: debug = self.debug error_handler = None exception_handlers: dict[Any, ExceptionHandler] = {} for key, value in self.exception_handlers.items(): if key in (500, Exception): error_handler = value else: exception_handlers[key] = value middleware = ( [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + self.user_middleware + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)] ) app = self.router for cls, args, kwargs in reversed(middleware): app = cls(app, *args, **kwargs) return app @property def routes(self) -> list[BaseRoute]: return self.router.routes def url_path_for(self, name: str, /, **path_params: Any) -> URLPath: return self.router.url_path_for(name, **path_params) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self if self.middleware_stack is None: self.middleware_stack = self.build_middleware_stack() await self.middleware_stack(scope, receive, send) def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: self.router.mount(path, app=app, name=name) # pragma: no cover def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: self.router.host(host, app=app, name=name) # pragma: no cover def add_middleware(self, middleware_class: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None: if self.middleware_stack is not None: # pragma: no cover raise RuntimeError("Cannot add middleware after an application has started") self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs)) def add_exception_handler( self, exc_class_or_status_code: int | type[Exception], handler: 
ExceptionHandler, ) -> None: # pragma: no cover self.exception_handlers[exc_class_or_status_code] = handler def add_route( self, path: str, route: Callable[[Request], Awaitable[Response] | Response], methods: list[str] | None = None, name: str | None = None, include_in_schema: bool = True, ) -> None: # pragma: no cover self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema) ================================================ FILE: starlette/authentication.py ================================================ from __future__ import annotations import functools import inspect from collections.abc import Callable, Sequence from typing import Any, ParamSpec from urllib.parse import urlencode from starlette._utils import is_async_callable from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection, Request from starlette.responses import RedirectResponse from starlette.websockets import WebSocket _P = ParamSpec("_P") def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool: for scope in scopes: if scope not in conn.auth.scopes: return False return True def requires( scopes: str | Sequence[str], status_code: int = 403, redirect: str | None = None, ) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) def decorator( func: Callable[_P, Any], ) -> Callable[_P, Any]: sig = inspect.signature(func) for idx, parameter in enumerate(sig.parameters.values()): if parameter.name == "request" or parameter.name == "websocket": type_ = parameter.name break else: raise Exception(f'No "request" or "websocket" argument on function "{func}"') if type_ == "websocket": # Handle websocket functions. 
(Always async) @functools.wraps(func) async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None) assert isinstance(websocket, WebSocket) if not has_required_scope(websocket, scopes_list): await websocket.close() else: await func(*args, **kwargs) return websocket_wrapper elif is_async_callable(func): # Handle async request/response functions. @functools.wraps(func) async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: request = kwargs.get("request", args[idx] if idx < len(args) else None) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): if redirect is not None: orig_request_qparam = urlencode({"next": str(request.url)}) next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" return RedirectResponse(url=next_url, status_code=303) raise HTTPException(status_code=status_code) return await func(*args, **kwargs) return async_wrapper else: # Handle sync request/response functions. 
@functools.wraps(func) def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: request = kwargs.get("request", args[idx] if idx < len(args) else None) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): if redirect is not None: orig_request_qparam = urlencode({"next": str(request.url)}) next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" return RedirectResponse(url=next_url, status_code=303) raise HTTPException(status_code=status_code) return func(*args, **kwargs) return sync_wrapper return decorator class AuthenticationError(Exception): pass class AuthenticationBackend: async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None: raise NotImplementedError() # pragma: no cover class AuthCredentials: def __init__(self, scopes: Sequence[str] | None = None): self.scopes = [] if scopes is None else list(scopes) class BaseUser: @property def is_authenticated(self) -> bool: raise NotImplementedError() # pragma: no cover @property def display_name(self) -> str: raise NotImplementedError() # pragma: no cover @property def identity(self) -> str: raise NotImplementedError() # pragma: no cover class SimpleUser(BaseUser): def __init__(self, username: str) -> None: self.username = username @property def is_authenticated(self) -> bool: return True @property def display_name(self) -> str: return self.username class UnauthenticatedUser(BaseUser): @property def is_authenticated(self) -> bool: return False @property def display_name(self) -> str: return "" ================================================ FILE: starlette/background.py ================================================ from __future__ import annotations from collections.abc import Callable, Sequence from typing import Any, ParamSpec from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool P = ParamSpec("P") class BackgroundTask: def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: 
P.kwargs) -> None: self.func = func self.args = args self.kwargs = kwargs self.is_async = is_async_callable(func) async def __call__(self) -> None: if self.is_async: await self.func(*self.args, **self.kwargs) else: await run_in_threadpool(self.func, *self.args, **self.kwargs) class BackgroundTasks(BackgroundTask): def __init__(self, tasks: Sequence[BackgroundTask] | None = None): self.tasks = list(tasks) if tasks else [] def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: task = BackgroundTask(func, *args, **kwargs) self.tasks.append(task) async def __call__(self) -> None: for task in self.tasks: await task() ================================================ FILE: starlette/concurrency.py ================================================ from __future__ import annotations import functools import warnings from collections.abc import AsyncIterator, Callable, Coroutine, Iterable, Iterator from typing import ParamSpec, TypeVar import anyio.to_thread P = ParamSpec("P") T = TypeVar("T") async def run_until_first_complete(*args: tuple[Callable, dict]) -> None: # type: ignore[type-arg] warnings.warn( "run_until_first_complete is deprecated and will be removed in a future version.", DeprecationWarning, ) async with anyio.create_task_group() as task_group: async def run(func: Callable[[], Coroutine]) -> None: # type: ignore[type-arg] await func() task_group.cancel_scope.cancel() for func, kwargs in args: task_group.start_soon(run, functools.partial(func, **kwargs)) async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: func = functools.partial(func, *args, **kwargs) return await anyio.to_thread.run_sync(func) class _StopIteration(Exception): pass def _next(iterator: Iterator[T]) -> T: # We can't raise `StopIteration` from within the threadpool iterator # and catch it outside that context, so we coerce them into a different # exception type. 
try: return next(iterator) except StopIteration: raise _StopIteration async def iterate_in_threadpool( iterator: Iterable[T], ) -> AsyncIterator[T]: as_iterator = iter(iterator) while True: try: yield await anyio.to_thread.run_sync(_next, as_iterator) except _StopIteration: break ================================================ FILE: starlette/config.py ================================================ from __future__ import annotations import os import warnings from collections.abc import Callable, Iterator, Mapping, MutableMapping from pathlib import Path from typing import Any, TypeVar, overload class undefined: pass class EnvironError(Exception): pass class Environ(MutableMapping[str, str]): def __init__(self, environ: MutableMapping[str, str] = os.environ): self._environ = environ self._has_been_read: set[str] = set() def __getitem__(self, key: str) -> str: self._has_been_read.add(key) return self._environ.__getitem__(key) def __setitem__(self, key: str, value: str) -> None: if key in self._has_been_read: raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.") self._environ.__setitem__(key, value) def __delitem__(self, key: str) -> None: if key in self._has_been_read: raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.") self._environ.__delitem__(key) def __iter__(self) -> Iterator[str]: return iter(self._environ) def __len__(self) -> int: return len(self._environ) environ = Environ() T = TypeVar("T") class Config: def __init__( self, env_file: str | Path | None = None, environ: Mapping[str, str] = environ, env_prefix: str = "", encoding: str = "utf-8", ) -> None: self.environ = environ self.env_prefix = env_prefix self.file_values: dict[str, str] = {} if env_file is not None: if not os.path.isfile(env_file): warnings.warn(f"Config file '{env_file}' not found.") else: self.file_values = self._read_file(env_file, encoding) @overload def __call__(self, key: str, *, default: None) 
-> str | None: ... @overload def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ... @overload def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ... @overload def __call__( self, key: str, cast: Callable[[Any], T] = ..., default: Any = ..., ) -> T: ... @overload def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ... def __call__( self, key: str, cast: Callable[[Any], Any] | None = None, default: Any = undefined, ) -> Any: return self.get(key, cast, default) def get( self, key: str, cast: Callable[[Any], Any] | None = None, default: Any = undefined, ) -> Any: key = self.env_prefix + key if key in self.environ: value = self.environ[key] return self._perform_cast(key, value, cast) if key in self.file_values: value = self.file_values[key] return self._perform_cast(key, value, cast) if default is not undefined: return self._perform_cast(key, default, cast) raise KeyError(f"Config '{key}' is missing, and has no default.") def _read_file(self, file_name: str | Path, encoding: str) -> dict[str, str]: file_values: dict[str, str] = {} with open(file_name, encoding=encoding) as input_file: for line in input_file.readlines(): line = line.strip() if "=" in line and not line.startswith("#"): key, value = line.split("=", 1) key = key.strip() value = value.strip().strip("\"'") file_values[key] = value return file_values def _perform_cast( self, key: str, value: Any, cast: Callable[[Any], Any] | None = None, ) -> Any: if cast is None or value is None: return value elif cast is bool and isinstance(value, str): mapping = {"true": True, "1": True, "false": False, "0": False} value = value.lower() if value not in mapping: raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.") return mapping[value] try: return cast(value) except (TypeError, ValueError): raise ValueError(f"Config '{key}' has value '{value}'. 
Not a valid {cast.__name__}.") ================================================ FILE: starlette/convertors.py ================================================ from __future__ import annotations import math import uuid from typing import Any, ClassVar, Generic, TypeVar T = TypeVar("T") class Convertor(Generic[T]): regex: ClassVar[str] = "" def convert(self, value: str) -> T: raise NotImplementedError() # pragma: no cover def to_string(self, value: T) -> str: raise NotImplementedError() # pragma: no cover class StringConvertor(Convertor[str]): regex = "[^/]+" def convert(self, value: str) -> str: return value def to_string(self, value: str) -> str: value = str(value) assert "/" not in value, "May not contain path separators" assert value, "Must not be empty" return value class PathConvertor(Convertor[str]): regex = ".*" def convert(self, value: str) -> str: return str(value) def to_string(self, value: str) -> str: return str(value) class IntegerConvertor(Convertor[int]): regex = "[0-9]+" def convert(self, value: str) -> int: return int(value) def to_string(self, value: int) -> str: value = int(value) assert value >= 0, "Negative integers are not supported" return str(value) class FloatConvertor(Convertor[float]): regex = r"[0-9]+(\.[0-9]+)?" 
def convert(self, value: str) -> float: return float(value) def to_string(self, value: float) -> str: value = float(value) assert value >= 0.0, "Negative floats are not supported" assert not math.isnan(value), "NaN values are not supported" assert not math.isinf(value), "Infinite values are not supported" return ("%0.20f" % value).rstrip("0").rstrip(".") class UUIDConvertor(Convertor[uuid.UUID]): regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}" def convert(self, value: str) -> uuid.UUID: return uuid.UUID(value) def to_string(self, value: uuid.UUID) -> str: return str(value) CONVERTOR_TYPES: dict[str, Convertor[Any]] = { "str": StringConvertor(), "path": PathConvertor(), "int": IntegerConvertor(), "float": FloatConvertor(), "uuid": UUIDConvertor(), } def register_url_convertor(key: str, convertor: Convertor[Any]) -> None: CONVERTOR_TYPES[key] = convertor ================================================ FILE: starlette/datastructures.py ================================================ from __future__ import annotations from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView from shlex import shlex from typing import ( Any, BinaryIO, NamedTuple, TypeVar, cast, ) from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit from starlette.concurrency import run_in_threadpool from starlette.types import Scope class Address(NamedTuple): host: str port: int _KeyType = TypeVar("_KeyType") # Mapping keys are invariant but their values are covariant since # you can only read them # that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` _CovariantValueType = TypeVar("_CovariantValueType", covariant=True) class URL: def __init__( self, url: str = "", scope: Scope | None = None, **components: Any, ) -> None: if scope is not None: assert not url, 'Cannot set both "url" and "scope".' assert not components, 'Cannot set both "scope" and "**components".' 
scheme = scope.get("scheme", "http") server = scope.get("server", None) path = scope["path"] query_string = scope.get("query_string", b"") host_header = None for key, value in scope["headers"]: if key == b"host": host_header = value.decode("latin-1") break if host_header is not None: url = f"{scheme}://{host_header}{path}" elif server is None: url = path else: host, port = server default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] if port == default_port: url = f"{scheme}://{host}{path}" else: url = f"{scheme}://{host}:{port}{path}" if query_string: url += "?" + query_string.decode() elif components: assert not url, 'Cannot set both "url" and "**components".' url = URL("").replace(**components).components.geturl() self._url = url @property def components(self) -> SplitResult: if not hasattr(self, "_components"): self._components = urlsplit(self._url) return self._components @property def scheme(self) -> str: return self.components.scheme @property def netloc(self) -> str: return self.components.netloc @property def path(self) -> str: return self.components.path @property def query(self) -> str: return self.components.query @property def fragment(self) -> str: return self.components.fragment @property def username(self) -> None | str: return self.components.username @property def password(self) -> None | str: return self.components.password @property def hostname(self) -> None | str: return self.components.hostname @property def port(self) -> int | None: return self.components.port @property def is_secure(self) -> bool: return self.scheme in ("https", "wss") def replace(self, **kwargs: Any) -> URL: if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs: hostname = kwargs.pop("hostname", None) port = kwargs.pop("port", self.port) username = kwargs.pop("username", self.username) password = kwargs.pop("password", self.password) if hostname is None: netloc = self.netloc _, _, hostname = netloc.rpartition("@") if 
hostname[-1] != "]": hostname = hostname.rsplit(":", 1)[0] netloc = hostname if port is not None: netloc += f":{port}" if username is not None: userpass = username if password is not None: userpass += f":{password}" netloc = f"{userpass}@{netloc}" kwargs["netloc"] = netloc components = self.components._replace(**kwargs) return self.__class__(components.geturl()) def include_query_params(self, **kwargs: Any) -> URL: params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) params.update({str(key): str(value) for key, value in kwargs.items()}) query = urlencode(params.multi_items()) return self.replace(query=query) def replace_query_params(self, **kwargs: Any) -> URL: query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) return self.replace(query=query) def remove_query_params(self, keys: str | Sequence[str]) -> URL: if isinstance(keys, str): keys = [keys] params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) for key in keys: params.pop(key, None) query = urlencode(params.multi_items()) return self.replace(query=query) def __eq__(self, other: Any) -> bool: return str(self) == str(other) def __str__(self) -> str: return self._url def __repr__(self) -> str: url = str(self) if self.password: url = str(self.replace(password="********")) return f"{self.__class__.__name__}({repr(url)})" class URLPath(str): """ A URL path string that may also hold an associated protocol and/or host. Used by the routing to return `url_path_for` matches. 
""" def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath: assert protocol in ("http", "websocket", "") return str.__new__(cls, path) def __init__(self, path: str, protocol: str = "", host: str = "") -> None: self.protocol = protocol self.host = host def make_absolute_url(self, base_url: str | URL) -> URL: if isinstance(base_url, str): base_url = URL(base_url) if self.protocol: scheme = { "http": {True: "https", False: "http"}, "websocket": {True: "wss", False: "ws"}, }[self.protocol][base_url.is_secure] else: scheme = base_url.scheme netloc = self.host or base_url.netloc path = base_url.path.rstrip("/") + str(self) return URL(scheme=scheme, netloc=netloc, path=path) class Secret: """ Holds a string value that should not be revealed in tracebacks etc. You should cast the value to `str` at the point it is required. """ def __init__(self, value: str): self._value = value def __repr__(self) -> str: class_name = self.__class__.__name__ return f"{class_name}('**********')" def __str__(self) -> str: return self._value def __bool__(self) -> bool: return bool(self._value) class CommaSeparatedStrings(Sequence[str]): def __init__(self, value: str | Sequence[str]): if isinstance(value, str): splitter = shlex(value, posix=True) splitter.whitespace = "," splitter.whitespace_split = True self._items = [item.strip() for item in splitter] else: self._items = list(value) def __len__(self) -> int: return len(self._items) def __getitem__(self, index: int | slice) -> Any: return self._items[index] def __iter__(self) -> Iterator[str]: return iter(self._items) def __repr__(self) -> str: class_name = self.__class__.__name__ items = [item for item in self] return f"{class_name}({items!r})" def __str__(self) -> str: return ", ".join(repr(item) for item in self) class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]): _dict: dict[_KeyType, _CovariantValueType] def __init__( self, *args: ImmutableMultiDict[_KeyType, _CovariantValueType] | Mapping[_KeyType, 
_CovariantValueType] | Iterable[tuple[_KeyType, _CovariantValueType]], **kwargs: Any, ) -> None: assert len(args) < 2, "Too many arguments." value: Any = args[0] if args else [] if kwargs: value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items() if not value: _items: list[tuple[Any, Any]] = [] elif hasattr(value, "multi_items"): value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value) _items = list(value.multi_items()) elif hasattr(value, "items"): value = cast(Mapping[_KeyType, _CovariantValueType], value) _items = list(value.items()) else: value = cast("list[tuple[Any, Any]]", value) _items = list(value) self._dict = {k: v for k, v in _items} self._list = _items def getlist(self, key: Any) -> list[_CovariantValueType]: return [item_value for item_key, item_value in self._list if item_key == key] def keys(self) -> KeysView[_KeyType]: return self._dict.keys() def values(self) -> ValuesView[_CovariantValueType]: return self._dict.values() def items(self) -> ItemsView[_KeyType, _CovariantValueType]: return self._dict.items() def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]: return list(self._list) def __getitem__(self, key: _KeyType) -> _CovariantValueType: return self._dict[key] def __contains__(self, key: Any) -> bool: return key in self._dict def __iter__(self) -> Iterator[_KeyType]: return iter(self.keys()) def __len__(self) -> int: return len(self._dict) def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False return sorted(self._list) == sorted(other._list) def __repr__(self) -> str: class_name = self.__class__.__name__ items = self.multi_items() return f"{class_name}({items!r})" class MultiDict(ImmutableMultiDict[Any, Any]): def __setitem__(self, key: Any, value: Any) -> None: self.setlist(key, [value]) def __delitem__(self, key: Any) -> None: self._list = [(k, v) for k, v in self._list if k != key] del self._dict[key] def pop(self, key: Any, default: Any = 
None) -> Any: self._list = [(k, v) for k, v in self._list if k != key] return self._dict.pop(key, default) def popitem(self) -> tuple[Any, Any]: key, value = self._dict.popitem() self._list = [(k, v) for k, v in self._list if k != key] return key, value def poplist(self, key: Any) -> list[Any]: values = [v for k, v in self._list if k == key] self.pop(key) return values def clear(self) -> None: self._dict.clear() self._list.clear() def setdefault(self, key: Any, default: Any = None) -> Any: if key not in self: self._dict[key] = default self._list.append((key, default)) return self[key] def setlist(self, key: Any, values: list[Any]) -> None: if not values: self.pop(key, None) else: existing_items = [(k, v) for (k, v) in self._list if k != key] self._list = existing_items + [(key, value) for value in values] self._dict[key] = values[-1] def append(self, key: Any, value: Any) -> None: self._list.append((key, value)) self._dict[key] = value def update( self, *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]], **kwargs: Any, ) -> None: value = MultiDict(*args, **kwargs) existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()] self._list = existing_items + value.multi_items() self._dict.update(value) class QueryParams(ImmutableMultiDict[str, str]): """ An immutable multidict. """ def __init__( self, *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes, **kwargs: Any, ) -> None: assert len(args) < 2, "Too many arguments." 
value = args[0] if args else [] if isinstance(value, str): super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs) elif isinstance(value, bytes): super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs) else: super().__init__(*args, **kwargs) # type: ignore[arg-type] self._list = [(str(k), str(v)) for k, v in self._list] self._dict = {str(k): str(v) for k, v in self._dict.items()} def __str__(self) -> str: return urlencode(self._list) def __repr__(self) -> str: class_name = self.__class__.__name__ query_string = str(self) return f"{class_name}({query_string!r})" class UploadFile: """ An uploaded file included as part of the request data. """ def __init__( self, file: BinaryIO, *, size: int | None = None, filename: str | None = None, headers: Headers | None = None, ) -> None: self.filename = filename self.file = file self.size = size self.headers = headers or Headers() # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks. 
# Note 0 means unlimited mirroring SpooledTemporaryFile's __init__ self._max_mem_size = getattr(self.file, "_max_size", 0) @property def content_type(self) -> str | None: return self.headers.get("content-type", None) @property def _in_memory(self) -> bool: # check for SpooledTemporaryFile._rolled rolled_to_disk = getattr(self.file, "_rolled", True) return not rolled_to_disk def _will_roll(self, size_to_add: int) -> bool: # If we're not in_memory then we will always roll if not self._in_memory: return True # Check for SpooledTemporaryFile._max_size future_size = self.file.tell() + size_to_add return bool(future_size > self._max_mem_size) if self._max_mem_size else False async def write(self, data: bytes) -> None: new_data_len = len(data) if self.size is not None: self.size += new_data_len if self._will_roll(new_data_len): await run_in_threadpool(self.file.write, data) else: self.file.write(data) async def read(self, size: int = -1) -> bytes: if self._in_memory: return self.file.read(size) return await run_in_threadpool(self.file.read, size) async def seek(self, offset: int) -> None: if self._in_memory: self.file.seek(offset) else: await run_in_threadpool(self.file.seek, offset) async def close(self) -> None: if self._in_memory: self.file.close() else: await run_in_threadpool(self.file.close) def __repr__(self) -> str: return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})" class FormData(ImmutableMultiDict[str, UploadFile | str]): """ An immutable multidict, containing both file uploads and text input. """ def __init__( self, *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]], **kwargs: str | UploadFile, ) -> None: super().__init__(*args, **kwargs) async def close(self) -> None: for key, value in self.multi_items(): if isinstance(value, UploadFile): await value.close() class Headers(Mapping[str, str]): """ An immutable, case-insensitive multidict. 
""" def __init__( self, headers: Mapping[str, str] | None = None, raw: list[tuple[bytes, bytes]] | None = None, scope: MutableMapping[str, Any] | None = None, ) -> None: self._list: list[tuple[bytes, bytes]] = [] if headers is not None: assert raw is None, 'Cannot set both "headers" and "raw".' assert scope is None, 'Cannot set both "headers" and "scope".' self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()] elif raw is not None: assert scope is None, 'Cannot set both "raw" and "scope".' self._list = raw elif scope is not None: # scope["headers"] isn't necessarily a list # it might be a tuple or other iterable self._list = scope["headers"] = list(scope["headers"]) @property def raw(self) -> list[tuple[bytes, bytes]]: return list(self._list) def keys(self) -> list[str]: # type: ignore[override] return [key.decode("latin-1") for key, value in self._list] def values(self) -> list[str]: # type: ignore[override] return [value.decode("latin-1") for key, value in self._list] def items(self) -> list[tuple[str, str]]: # type: ignore[override] return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list] def getlist(self, key: str) -> list[str]: get_header_key = key.lower().encode("latin-1") return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key] def mutablecopy(self) -> MutableHeaders: return MutableHeaders(raw=self._list[:]) def __getitem__(self, key: str) -> str: get_header_key = key.lower().encode("latin-1") for header_key, header_value in self._list: if header_key == get_header_key: return header_value.decode("latin-1") raise KeyError(key) def __contains__(self, key: Any) -> bool: get_header_key = key.lower().encode("latin-1") for header_key, header_value in self._list: if header_key == get_header_key: return True return False def __iter__(self) -> Iterator[Any]: return iter(self.keys()) def __len__(self) -> int: return len(self._list) def 
__eq__(self, other: Any) -> bool: if not isinstance(other, Headers): return False return sorted(self._list) == sorted(other._list) def __repr__(self) -> str: class_name = self.__class__.__name__ as_dict = dict(self.items()) if len(as_dict) == len(self): return f"{class_name}({as_dict!r})" return f"{class_name}(raw={self.raw!r})" class MutableHeaders(Headers): def __setitem__(self, key: str, value: str) -> None: """ Set the header `key` to `value`, removing any duplicate entries. Retains insertion order. """ set_key = key.lower().encode("latin-1") set_value = value.encode("latin-1") found_indexes: list[int] = [] for idx, (item_key, item_value) in enumerate(self._list): if item_key == set_key: found_indexes.append(idx) for idx in reversed(found_indexes[1:]): del self._list[idx] if found_indexes: idx = found_indexes[0] self._list[idx] = (set_key, set_value) else: self._list.append((set_key, set_value)) def __delitem__(self, key: str) -> None: """ Remove the header `key`. """ del_key = key.lower().encode("latin-1") pop_indexes: list[int] = [] for idx, (item_key, item_value) in enumerate(self._list): if item_key == del_key: pop_indexes.append(idx) for idx in reversed(pop_indexes): del self._list[idx] def __ior__(self, other: Mapping[str, str]) -> MutableHeaders: if not isinstance(other, Mapping): raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") self.update(other) return self def __or__(self, other: Mapping[str, str]) -> MutableHeaders: if not isinstance(other, Mapping): raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") new = self.mutablecopy() new.update(other) return new @property def raw(self) -> list[tuple[bytes, bytes]]: return self._list def setdefault(self, key: str, value: str) -> str: """ If the header `key` does not exist, then set it to `value`. Returns the header value. 
""" set_key = key.lower().encode("latin-1") set_value = value.encode("latin-1") for idx, (item_key, item_value) in enumerate(self._list): if item_key == set_key: return item_value.decode("latin-1") self._list.append((set_key, set_value)) return value def update(self, other: Mapping[str, str]) -> None: for key, val in other.items(): self[key] = val def append(self, key: str, value: str) -> None: """ Append a header, preserving any duplicate entries. """ append_key = key.lower().encode("latin-1") append_value = value.encode("latin-1") self._list.append((append_key, append_value)) def add_vary_header(self, vary: str) -> None: existing = self.get("vary") if existing is not None: vary = ", ".join([existing, vary]) self["vary"] = vary class State: """ An object that can be used to store arbitrary state. Used for `request.state` and `app.state`. """ _state: dict[str, Any] def __init__(self, state: dict[str, Any] | None = None): if state is None: state = {} super().__setattr__("_state", state) def __setattr__(self, key: Any, value: Any) -> None: self._state[key] = value def __getattr__(self, key: Any) -> Any: try: return self._state[key] except KeyError: message = "'{}' object has no attribute '{}'" raise AttributeError(message.format(self.__class__.__name__, key)) def __delattr__(self, key: Any) -> None: del self._state[key] def __getitem__(self, key: str) -> Any: return self._state[key] def __setitem__(self, key: str, value: Any) -> None: self._state[key] = value def __delitem__(self, key: str) -> None: del self._state[key] def __iter__(self) -> Iterator[str]: return iter(self._state) def __len__(self) -> int: return len(self._state) ================================================ FILE: starlette/endpoints.py ================================================ from __future__ import annotations import json from collections.abc import Callable, Generator from typing import Any, Literal from starlette import status from starlette._utils import is_async_callable from 
starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket class HTTPEndpoint: def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] == "http" self.scope = scope self.receive = receive self.send = send self._allowed_methods = [ method for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS") if getattr(self, method.lower(), None) is not None ] def __await__(self) -> Generator[Any, None, None]: return self.dispatch().__await__() async def dispatch(self) -> None: request = Request(self.scope, receive=self.receive) handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower() handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed) is_async = is_async_callable(handler) if is_async: response = await handler(request) else: response = await run_in_threadpool(handler, request) await response(self.scope, self.receive, self.send) async def method_not_allowed(self, request: Request) -> Response: # If we're running inside a starlette application then raise an # exception, so that the configurable exception handler can deal with # returning the response. For plain ASGI apps, just return the response. 
headers = {"Allow": ", ".join(self._allowed_methods)} if "app" in self.scope: raise HTTPException(status_code=405, headers=headers) return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) class WebSocketEndpoint: encoding: Literal["text", "bytes", "json"] | None = None def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] == "websocket" self.scope = scope self.receive = receive self.send = send def __await__(self) -> Generator[Any, None, None]: return self.dispatch().__await__() async def dispatch(self) -> None: websocket = WebSocket(self.scope, receive=self.receive, send=self.send) await self.on_connect(websocket) close_code = status.WS_1000_NORMAL_CLOSURE try: while True: message = await websocket.receive() if message["type"] == "websocket.receive": data = await self.decode(websocket, message) await self.on_receive(websocket, data) elif message["type"] == "websocket.disconnect": # pragma: no branch close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE) break except Exception as exc: close_code = status.WS_1011_INTERNAL_ERROR raise exc finally: await self.on_disconnect(websocket, close_code) async def decode(self, websocket: WebSocket, message: Message) -> Any: if self.encoding == "text": if "text" not in message: await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) raise RuntimeError("Expected text websocket messages, but got bytes") return message["text"] elif self.encoding == "bytes": if "bytes" not in message: await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) raise RuntimeError("Expected bytes websocket messages, but got text") return message["bytes"] elif self.encoding == "json": if message.get("text") is not None: text = message["text"] else: text = message["bytes"].decode("utf-8") try: return json.loads(text) except json.decoder.JSONDecodeError: await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) raise RuntimeError("Malformed JSON data received.") assert 
self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}" return message["text"] if message.get("text") else message["bytes"] async def on_connect(self, websocket: WebSocket) -> None: """Override to handle an incoming websocket connection""" await websocket.accept() async def on_receive(self, websocket: WebSocket, data: Any) -> None: """Override to handle an incoming websocket message""" async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: """Override to handle a disconnecting websocket""" ================================================ FILE: starlette/exceptions.py ================================================ from __future__ import annotations import http from collections.abc import Mapping class HTTPException(Exception): def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None: if detail is None: detail = http.HTTPStatus(status_code).phrase self.status_code = status_code self.detail = detail self.headers = headers def __str__(self) -> str: return f"{self.status_code}: {self.detail}" def __repr__(self) -> str: class_name = self.__class__.__name__ return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" class WebSocketException(Exception): def __init__(self, code: int, reason: str | None = None) -> None: self.code = code self.reason = reason or "" def __str__(self) -> str: return f"{self.code}: {self.reason}" def __repr__(self) -> str: class_name = self.__class__.__name__ return f"{class_name}(code={self.code!r}, reason={self.reason!r})" ================================================ FILE: starlette/formparsers.py ================================================ from __future__ import annotations from collections.abc import AsyncGenerator from dataclasses import dataclass, field from enum import Enum from tempfile import SpooledTemporaryFile from typing import TYPE_CHECKING from urllib.parse import unquote_plus from 
starlette.datastructures import FormData, Headers, UploadFile if TYPE_CHECKING: import python_multipart as multipart from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header else: try: try: import python_multipart as multipart from python_multipart.multipart import parse_options_header except ModuleNotFoundError: # pragma: no cover import multipart from multipart.multipart import parse_options_header except ModuleNotFoundError: # pragma: no cover multipart = None parse_options_header = None class FormMessage(Enum): FIELD_START = 1 FIELD_NAME = 2 FIELD_DATA = 3 FIELD_END = 4 END = 5 @dataclass class MultipartPart: content_disposition: bytes | None = None field_name: str = "" data: bytearray = field(default_factory=bytearray) file: UploadFile | None = None item_headers: list[tuple[bytes, bytes]] = field(default_factory=list) def _user_safe_decode(src: bytes | bytearray, codec: str) -> str: try: return src.decode(codec) except (UnicodeDecodeError, LookupError): return src.decode("latin-1") class MultiPartException(Exception): def __init__(self, message: str) -> None: self.message = message class FormParser: def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None: assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." 
self.headers = headers self.stream = stream self.messages: list[tuple[FormMessage, bytes]] = [] def on_field_start(self) -> None: message = (FormMessage.FIELD_START, b"") self.messages.append(message) def on_field_name(self, data: bytes, start: int, end: int) -> None: message = (FormMessage.FIELD_NAME, data[start:end]) self.messages.append(message) def on_field_data(self, data: bytes, start: int, end: int) -> None: message = (FormMessage.FIELD_DATA, data[start:end]) self.messages.append(message) def on_field_end(self) -> None: message = (FormMessage.FIELD_END, b"") self.messages.append(message) def on_end(self) -> None: message = (FormMessage.END, b"") self.messages.append(message) async def parse(self) -> FormData: # Callbacks dictionary. callbacks: QuerystringCallbacks = { "on_field_start": self.on_field_start, "on_field_name": self.on_field_name, "on_field_data": self.on_field_data, "on_field_end": self.on_field_end, "on_end": self.on_end, } # Create the parser. parser = multipart.QuerystringParser(callbacks) field_name = bytearray() field_value = bytearray() items: list[tuple[str, str | UploadFile]] = [] # Feed the parser with data from the request. 
async for chunk in self.stream: if chunk: parser.write(chunk) else: parser.finalize() messages = list(self.messages) self.messages.clear() for message_type, message_bytes in messages: if message_type == FormMessage.FIELD_START: field_name = bytearray() field_value = bytearray() elif message_type == FormMessage.FIELD_NAME: field_name.extend(message_bytes) elif message_type == FormMessage.FIELD_DATA: field_value.extend(message_bytes) elif message_type == FormMessage.FIELD_END: name = unquote_plus(field_name.decode("latin-1")) value = unquote_plus(field_value.decode("latin-1")) items.append((name, value)) return FormData(items) class MultiPartParser: spool_max_size = 1024 * 1024 # 1MB """The maximum size of the spooled temporary file used to store file data.""" max_part_size = 1024 * 1024 # 1MB """The maximum size of a part in the multipart request.""" def __init__( self, headers: Headers, stream: AsyncGenerator[bytes, None], *, max_files: int | float = 1000, max_fields: int | float = 1000, max_part_size: int = 1024 * 1024, # 1MB ) -> None: assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." 
self.headers = headers self.stream = stream self.max_files = max_files self.max_fields = max_fields self.items: list[tuple[str, str | UploadFile]] = [] self._current_files = 0 self._current_fields = 0 self._current_partial_header_name: bytes = b"" self._current_partial_header_value: bytes = b"" self._current_part = MultipartPart() self._charset = "" self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = [] self._file_parts_to_finish: list[MultipartPart] = [] self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] self.max_part_size = max_part_size def on_part_begin(self) -> None: self._current_part = MultipartPart() def on_part_data(self, data: bytes, start: int, end: int) -> None: message_bytes = data[start:end] if self._current_part.file is None: if len(self._current_part.data) + len(message_bytes) > self.max_part_size: raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.") self._current_part.data.extend(message_bytes) else: self._file_parts_to_write.append((self._current_part, message_bytes)) def on_part_end(self) -> None: if self._current_part.file is None: self.items.append( ( self._current_part.field_name, _user_safe_decode(self._current_part.data, self._charset), ) ) else: self._file_parts_to_finish.append(self._current_part) # The file can be added to the items right now even though it's not # finished yet, because it will be finished in the `parse()` method, before # self.items is used in the return value. 
self.items.append((self._current_part.field_name, self._current_part.file)) def on_header_field(self, data: bytes, start: int, end: int) -> None: self._current_partial_header_name += data[start:end] def on_header_value(self, data: bytes, start: int, end: int) -> None: self._current_partial_header_value += data[start:end] def on_header_end(self) -> None: field = self._current_partial_header_name.lower() if field == b"content-disposition": self._current_part.content_disposition = self._current_partial_header_value self._current_part.item_headers.append((field, self._current_partial_header_value)) self._current_partial_header_name = b"" self._current_partial_header_value = b"" def on_headers_finished(self) -> None: disposition, options = parse_options_header(self._current_part.content_disposition) try: self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset) except KeyError: raise MultiPartException('The Content-Disposition header field "name" must be provided.') if b"filename" in options: self._current_files += 1 if self._current_files > self.max_files: raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") filename = _user_safe_decode(options[b"filename"], self._charset) tempfile = SpooledTemporaryFile(max_size=self.spool_max_size) self._files_to_close_on_error.append(tempfile) self._current_part.file = UploadFile( file=tempfile, # type: ignore[arg-type] size=0, filename=filename, headers=Headers(raw=self._current_part.item_headers), ) else: self._current_fields += 1 if self._current_fields > self.max_fields: raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.") self._current_part.file = None def on_end(self) -> None: pass async def parse(self) -> FormData: # Parse the Content-Type header to get the multipart boundary. 
_, params = parse_options_header(self.headers["Content-Type"]) charset = params.get(b"charset", "utf-8") if isinstance(charset, bytes): charset = charset.decode("latin-1") self._charset = charset try: boundary = params[b"boundary"] except KeyError: raise MultiPartException("Missing boundary in multipart.") # Callbacks dictionary. callbacks: MultipartCallbacks = { "on_part_begin": self.on_part_begin, "on_part_data": self.on_part_data, "on_part_end": self.on_part_end, "on_header_field": self.on_header_field, "on_header_value": self.on_header_value, "on_header_end": self.on_header_end, "on_headers_finished": self.on_headers_finished, "on_end": self.on_end, } # Create the parser. parser = multipart.MultipartParser(boundary, callbacks) try: # Feed the parser with data from the request. async for chunk in self.stream: parser.write(chunk) # Write file data, it needs to use await with the UploadFile methods # that call the corresponding file methods *in a threadpool*, # otherwise, if they were called directly in the callback methods above # (regular, non-async functions), that would block the event loop in # the main thread. for part, data in self._file_parts_to_write: assert part.file # for type checkers await part.file.write(data) for part in self._file_parts_to_finish: assert part.file # for type checkers await part.file.seek(0) self._file_parts_to_write.clear() self._file_parts_to_finish.clear() parser.finalize() except MultiPartException as exc: # Close all the files if there was an error. 
for file in self._files_to_close_on_error: file.close() raise exc return FormData(self.items) ================================================ FILE: starlette/middleware/__init__.py ================================================ from __future__ import annotations from collections.abc import Awaitable, Callable, Iterator from typing import Any, ParamSpec, Protocol P = ParamSpec("P") _Scope = Any _Receive = Callable[[], Awaitable[Any]] _Send = Callable[[Any], Awaitable[None]] # Since `starlette.types.ASGIApp` type differs from `ASGIApplication` from `asgiref` # we need to define a more permissive version of ASGIApp that doesn't cause type errors. _ASGIApp = Callable[[_Scope, _Receive, _Send], Awaitable[None]] class _MiddlewareFactory(Protocol[P]): def __call__(self, app: _ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> _ASGIApp: ... # pragma: no cover class Middleware: def __init__(self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None: self.cls = cls self.args = args self.kwargs = kwargs def __iter__(self) -> Iterator[Any]: as_tuple = (self.cls, self.args, self.kwargs) return iter(as_tuple) def __repr__(self) -> str: class_name = self.__class__.__name__ args_strings = [f"{value!r}" for value in self.args] option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()] name = getattr(self.cls, "__name__", "") args_repr = ", ".join([name] + args_strings + option_strings) return f"{class_name}({args_repr})" ================================================ FILE: starlette/middleware/authentication.py ================================================ from __future__ import annotations from collections.abc import Callable from starlette.authentication import ( AuthCredentials, AuthenticationBackend, AuthenticationError, UnauthenticatedUser, ) from starlette.requests import HTTPConnection from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send class AuthenticationMiddleware: 
def __init__( self, app: ASGIApp, backend: AuthenticationBackend, on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None, ) -> None: self.app = app self.backend = backend self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = ( on_error if on_error is not None else self.default_on_error ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ["http", "websocket"]: await self.app(scope, receive, send) return conn = HTTPConnection(scope) try: auth_result = await self.backend.authenticate(conn) except AuthenticationError as exc: response = self.on_error(conn, exc) if scope["type"] == "websocket": await send({"type": "websocket.close", "code": 1000}) else: await response(scope, receive, send) return if auth_result is None: auth_result = AuthCredentials(), UnauthenticatedUser() scope["auth"], scope["user"] = auth_result await self.app(scope, receive, send) @staticmethod def default_on_error(conn: HTTPConnection, exc: Exception) -> Response: return PlainTextResponse(str(exc), status_code=400) ================================================ FILE: starlette/middleware/base.py ================================================ from __future__ import annotations from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping from typing import Any, TypeVar import anyio from starlette._utils import collapse_excgroups from starlette.requests import ClientDisconnect, Request from starlette.responses import Response from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = Callable[[Request], Awaitable[Response]] DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]] BodyStreamGenerator = AsyncGenerator[bytes | MutableMapping[str, Any], None] AsyncContentStream = AsyncIterable[str | bytes | memoryview | MutableMapping[str, Any]] T = TypeVar("T") class _CachedRequest(Request): """ If 
the user calls Request.body() from their dispatch function we cache the entire request body in memory and pass that to downstream middlewares, but if they call Request.stream() then all we do is send an empty body so that downstream things don't hang forever. """ def __init__(self, scope: Scope, receive: Receive): super().__init__(scope, receive) self._wrapped_rcv_disconnected = False self._wrapped_rcv_consumed = False self._wrapped_rc_stream = self.stream() async def wrapped_receive(self) -> Message: # wrapped_rcv state 1: disconnected if self._wrapped_rcv_disconnected: # we've already sent a disconnect to the downstream app # we don't need to wait to get another one # (although most ASGI servers will just keep sending it) return {"type": "http.disconnect"} # wrapped_rcv state 1: consumed but not yet disconnected if self._wrapped_rcv_consumed: # since the downstream app has consumed us all that is left # is to send it a disconnect if self._is_disconnected: # the middleware has already seen the disconnect # since we know the client is disconnected no need to wait # for the message self._wrapped_rcv_disconnected = True return {"type": "http.disconnect"} # we don't know yet if the client is disconnected or not # so we'll wait until we get that message msg = await self.receive() if msg["type"] != "http.disconnect": # pragma: no cover # at this point a disconnect is all that we should be receiving # if we get something else, things went wrong somewhere raise RuntimeError(f"Unexpected message received: {msg['type']}") self._wrapped_rcv_disconnected = True return msg # wrapped_rcv state 3: not yet consumed if getattr(self, "_body", None) is not None: # body() was called, we return it even if the client disconnected self._wrapped_rcv_consumed = True return { "type": "http.request", "body": self._body, "more_body": False, } elif self._stream_consumed: # stream() was called to completion # return an empty body so that downstream apps don't hang # waiting for a disconnect 
self._wrapped_rcv_consumed = True return { "type": "http.request", "body": b"", "more_body": False, } else: # body() was never called and stream() wasn't consumed try: stream = self.stream() chunk = await stream.__anext__() self._wrapped_rcv_consumed = self._stream_consumed return { "type": "http.request", "body": chunk, "more_body": not self._stream_consumed, } except ClientDisconnect: self._wrapped_rcv_disconnected = True return {"type": "http.disconnect"} class BaseHTTPMiddleware: def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None: self.app = app self.dispatch_func = self.dispatch if dispatch is None else dispatch async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": await self.app(scope, receive, send) return request = _CachedRequest(scope, receive) wrapped_receive = request.wrapped_receive response_sent = anyio.Event() app_exc: Exception | None = None exception_already_raised = False async def call_next(request: Request) -> Response: async def receive_or_disconnect() -> Message: if response_sent.is_set(): return {"type": "http.disconnect"} async with anyio.create_task_group() as task_group: async def wrap(func: Callable[[], Awaitable[T]]) -> T: result = await func() task_group.cancel_scope.cancel() return result task_group.start_soon(wrap, response_sent.wait) message = await wrap(wrapped_receive) if response_sent.is_set(): return {"type": "http.disconnect"} return message async def send_no_error(message: Message) -> None: try: await send_stream.send(message) except anyio.BrokenResourceError: # recv_stream has been closed, i.e. response_sent has been set. 
return async def coro() -> None: nonlocal app_exc with send_stream: try: await self.app(scope, receive_or_disconnect, send_no_error) except Exception as exc: app_exc = exc task_group.start_soon(coro) try: message = await recv_stream.receive() info = message.get("info", None) if message["type"] == "http.response.debug" and info is not None: message = await recv_stream.receive() except anyio.EndOfStream: if app_exc is not None: nonlocal exception_already_raised exception_already_raised = True # Prevent `anyio.EndOfStream` from polluting app exception context. # If both cause and context are None then the context is suppressed # and `anyio.EndOfStream` is not present in the exception traceback. # If exception cause is not None then it is propagated with # reraising here. # If exception has no cause but has context set then the context is # propagated as a cause with the reraise. This is necessary in order # to prevent `anyio.EndOfStream` from polluting the exception # context. raise app_exc from app_exc.__cause__ or app_exc.__context__ raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" async def body_stream() -> BodyStreamGenerator: async for message in recv_stream: if message["type"] == "http.response.pathsend": yield message break assert message["type"] == "http.response.body", f"Unexpected message: {message}" body = message.get("body", b"") if body: yield body if not message.get("more_body", False): break response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info) response.raw_headers = message["headers"] return response streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream() send_stream, recv_stream = streams with recv_stream, send_stream, collapse_excgroups(): async with anyio.create_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) response_sent.set() recv_stream.close() if 
app_exc is not None and not exception_already_raised: raise app_exc async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: raise NotImplementedError() # pragma: no cover class _StreamingResponse(Response): def __init__( self, content: AsyncContentStream, status_code: int = 200, headers: Mapping[str, str] | None = None, media_type: str | None = None, info: Mapping[str, Any] | None = None, ) -> None: self.info = info self.body_iterator = content self.status_code = status_code self.media_type = media_type self.init_headers(headers) self.background = None async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.info is not None: await send({"type": "http.response.debug", "info": self.info}) await send( { "type": "http.response.start", "status": self.status_code, "headers": self.raw_headers, } ) should_close_body = True async for chunk in self.body_iterator: if isinstance(chunk, dict): # We got an ASGI message which is not response body (eg: pathsend) should_close_body = False await send(chunk) continue await send({"type": "http.response.body", "body": chunk, "more_body": True}) if should_close_body: await send({"type": "http.response.body", "body": b"", "more_body": False}) if self.background: await self.background() ================================================ FILE: starlette/middleware/cors.py ================================================ from __future__ import annotations import functools import re from collections.abc import Sequence from starlette.datastructures import Headers, MutableHeaders from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} class CORSMiddleware: def __init__( self, app: ASGIApp, allow_origins: Sequence[str] = (), allow_methods: 
Sequence[str] = ("GET",), allow_headers: Sequence[str] = (), allow_credentials: bool = False, allow_origin_regex: str | None = None, allow_private_network: bool = False, expose_headers: Sequence[str] = (), max_age: int = 600, ) -> None: if "*" in allow_methods: allow_methods = ALL_METHODS compiled_allow_origin_regex = None if allow_origin_regex is not None: compiled_allow_origin_regex = re.compile(allow_origin_regex) allow_all_origins = "*" in allow_origins allow_all_headers = "*" in allow_headers preflight_explicit_allow_origin = not allow_all_origins or allow_credentials simple_headers: dict[str, str] = {} if allow_all_origins: simple_headers["Access-Control-Allow-Origin"] = "*" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" if expose_headers: simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) preflight_headers: dict[str, str] = {} if preflight_explicit_allow_origin: # The origin value will be set in preflight_response() if it is allowed. 
preflight_headers["Vary"] = "Origin" else: preflight_headers["Access-Control-Allow-Origin"] = "*" preflight_headers.update( { "Access-Control-Allow-Methods": ", ".join(allow_methods), "Access-Control-Max-Age": str(max_age), } ) allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) if allow_headers and not allow_all_headers: preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) if allow_credentials: preflight_headers["Access-Control-Allow-Credentials"] = "true" self.app = app self.allow_origins = allow_origins self.allow_methods = allow_methods self.allow_headers = [h.lower() for h in allow_headers] self.allow_all_origins = allow_all_origins self.allow_all_headers = allow_all_headers self.allow_credentials = allow_credentials self.preflight_explicit_allow_origin = preflight_explicit_allow_origin self.allow_origin_regex = compiled_allow_origin_regex self.allow_private_network = allow_private_network self.simple_headers = simple_headers self.preflight_headers = preflight_headers async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": # pragma: no cover await self.app(scope, receive, send) return method = scope["method"] headers = Headers(scope=scope) origin = headers.get("origin") if origin is None: await self.app(scope, receive, send) return if method == "OPTIONS" and "access-control-request-method" in headers: response = self.preflight_response(request_headers=headers) await response(scope, receive, send) return await self.simple_response(scope, receive, send, request_headers=headers) def is_allowed_origin(self, origin: str) -> bool: if self.allow_all_origins: return True if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin): return True return origin in self.allow_origins def preflight_response(self, request_headers: Headers) -> Response: requested_origin = request_headers["origin"] requested_method = request_headers["access-control-request-method"] 
requested_headers = request_headers.get("access-control-request-headers") requested_private_network = request_headers.get("access-control-request-private-network") headers = dict(self.preflight_headers) failures: list[str] = [] if self.is_allowed_origin(origin=requested_origin): if self.preflight_explicit_allow_origin: # The "else" case is already accounted for in self.preflight_headers # and the value would be "*". headers["Access-Control-Allow-Origin"] = requested_origin else: failures.append("origin") if requested_method not in self.allow_methods: failures.append("method") # If we allow all headers, then we have to mirror back any requested # headers in the response. if self.allow_all_headers and requested_headers is not None: headers["Access-Control-Allow-Headers"] = requested_headers elif requested_headers is not None: for header in [h.lower() for h in requested_headers.split(",")]: if header.strip() not in self.allow_headers: failures.append("headers") break if requested_private_network is not None: if self.allow_private_network: headers["Access-Control-Allow-Private-Network"] = "true" else: failures.append("private-network") # We don't strictly need to use 400 responses here, since its up to # the browser to enforce the CORS policy, but its more informative # if we do. 
if failures: failure_text = "Disallowed CORS " + ", ".join(failures) return PlainTextResponse(failure_text, status_code=400, headers=headers) return PlainTextResponse("OK", status_code=200, headers=headers) async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None: send = functools.partial(self.send, send=send, request_headers=request_headers) await self.app(scope, receive, send) async def send(self, message: Message, send: Send, request_headers: Headers) -> None: if message["type"] != "http.response.start": await send(message) return message.setdefault("headers", []) headers = MutableHeaders(scope=message) headers.update(self.simple_headers) origin = request_headers["Origin"] # If credentials are allowed, then we must respond with the specific origin instead of '*'. if self.allow_all_origins and self.allow_credentials: self.allow_explicit_origin(headers, origin) # If we only allow specific origins, then we have to mirror back the Origin header in the response. 
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): self.allow_explicit_origin(headers, origin) await send(message) @staticmethod def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: headers["Access-Control-Allow-Origin"] = origin headers.add_vary_header("Origin") ================================================ FILE: starlette/middleware/errors.py ================================================ from __future__ import annotations import html import inspect import sys import traceback from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import HTMLResponse, PlainTextResponse, Response from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send STYLES = """ p { color: #211c1c; } .traceback-container { border: 1px solid #038BB8; } .traceback-title { background-color: #038BB8; color: lemonchiffon; padding: 12px; font-size: 20px; margin-top: 0px; } .frame-line { padding-left: 10px; font-family: monospace; } .frame-filename { font-family: monospace; } .center-line { background-color: #038BB8; color: #f9f6e1; padding: 5px 0px 5px 5px; } .lineno { margin-right: 5px; } .frame-title { font-weight: unset; padding: 10px 10px 10px 10px; background-color: #E4F4FD; margin-right: 10px; color: #191f21; font-size: 17px; border: 1px solid #c7dce8; } .collapse-btn { float: right; padding: 0px 5px 1px 5px; border: solid 1px #96aebb; cursor: pointer; } .collapsed { display: none; } .source-code { font-family: courier; font-size: small; padding-bottom: 10px; } """ JS = """ """ TEMPLATE = """ Starlette Debugger

    500 Server Error

    {error}

    Traceback

    {exc_html}
    {js} """ FRAME_TEMPLATE = """

    File {frame_filename}, line {frame_lineno}, in {frame_name} {collapse_button}

    {code_context}
    """ # noqa: E501 LINE = """

    {lineno}. {line}

    """ CENTER_LINE = """

    {lineno}. {line}

    """ class ServerErrorMiddleware: """ Handles returning 500 responses when a server error occurs. If 'debug' is set, then traceback responses will be returned, otherwise the designated 'handler' will be called. This middleware class should generally be used to wrap *everything* else up, so that unhandled exceptions anywhere in the stack always result in an appropriate 500 response. """ def __init__( self, app: ASGIApp, handler: ExceptionHandler | None = None, debug: bool = False, ) -> None: self.app = app self.handler = handler self.debug = debug async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": await self.app(scope, receive, send) return response_started = False async def _send(message: Message) -> None: nonlocal response_started, send if message["type"] == "http.response.start": response_started = True await send(message) try: await self.app(scope, receive, _send) except Exception as exc: request = Request(scope) if self.debug: # In debug mode, return traceback responses. response = self.debug_response(request, exc) elif self.handler is None: # Use our default 500 error handler. response = self.error_response(request, exc) else: # Use an installed 500 error handler. if is_async_callable(self.handler): response = await self.handler(request, exc) else: response = await run_in_threadpool(self.handler, request, exc) if not response_started: await response(scope, receive, send) # We always continue to raise the exception. # This allows servers to log the error, or allows test clients # to optionally raise the error within the test case. 
raise exc def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str: values = { # HTML escape - line could contain < or > "line": html.escape(line).replace(" ", " "), "lineno": (frame_lineno - frame_index) + index, } if index != frame_index: return LINE.format(**values) return CENTER_LINE.format(**values) def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str: code_context = "".join( self.format_line( index, line, frame.lineno, frame.index, # type: ignore[arg-type] ) for index, line in enumerate(frame.code_context or []) ) values = { # HTML escape - filename could contain < or >, especially if it's a virtual # file e.g. in the REPL "frame_filename": html.escape(frame.filename), "frame_lineno": frame.lineno, # HTML escape - if you try very hard it's possible to name a function with < # or > "frame_name": html.escape(frame.function), "code_context": code_context, "collapsed": "collapsed" if is_collapsed else "", "collapse_button": "+" if is_collapsed else "‒", } return FRAME_TEMPLATE.format(**values) def generate_html(self, exc: Exception, limit: int = 7) -> str: traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True) exc_html = "" is_collapsed = False exc_traceback = exc.__traceback__ if exc_traceback is not None: frames = inspect.getinnerframes(exc_traceback, limit) for frame in reversed(frames): exc_html += self.generate_frame_html(frame, is_collapsed) is_collapsed = True if sys.version_info >= (3, 13): # pragma: no cover exc_type_str = traceback_obj.exc_type_str else: # pragma: no cover exc_type_str = traceback_obj.exc_type.__name__ # escape error class and text error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}" return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html) def generate_plain_text(self, exc: Exception) -> str: return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) def debug_response(self, request: 
Request, exc: Exception) -> Response: accept = request.headers.get("accept", "") if "text/html" in accept: content = self.generate_html(exc) return HTMLResponse(content, status_code=500) content = self.generate_plain_text(exc) return PlainTextResponse(content, status_code=500) def error_response(self, request: Request, exc: Exception) -> Response: return PlainTextResponse("Internal Server Error", status_code=500) ================================================ FILE: starlette/middleware/exceptions.py ================================================ from __future__ import annotations from collections.abc import Mapping from typing import Any from starlette._exception_handler import ( ExceptionHandlers, StatusHandlers, wrap_app_handling_exceptions, ) from starlette.exceptions import HTTPException, WebSocketException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send from starlette.websockets import WebSocket class ExceptionMiddleware: def __init__( self, app: ASGIApp, handlers: Mapping[Any, ExceptionHandler] | None = None, debug: bool = False, ) -> None: self.app = app self.debug = debug # TODO: We ought to handle 404 cases if debug is set. 
self._status_handlers: StatusHandlers = {} self._exception_handlers: ExceptionHandlers = { HTTPException: self.http_exception, WebSocketException: self.websocket_exception, } if handlers is not None: # pragma: no branch for key, value in handlers.items(): self.add_exception_handler(key, value) def add_exception_handler( self, exc_class_or_status_code: int | type[Exception], handler: ExceptionHandler, ) -> None: if isinstance(exc_class_or_status_code, int): self._status_handlers[exc_class_or_status_code] = handler else: assert issubclass(exc_class_or_status_code, Exception) self._exception_handlers[exc_class_or_status_code] = handler async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ("http", "websocket"): await self.app(scope, receive, send) return scope["starlette.exception_handlers"] = ( self._exception_handlers, self._status_handlers, ) conn: Request | WebSocket if scope["type"] == "http": conn = Request(scope, receive, send) else: conn = WebSocket(scope, receive, send) await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send) async def http_exception(self, request: Request, exc: Exception) -> Response: assert isinstance(exc, HTTPException) if exc.status_code in {204, 304}: return Response(status_code=exc.status_code, headers=exc.headers) return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers) async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None: assert isinstance(exc, WebSocketException) await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover ================================================ FILE: starlette/middleware/gzip.py ================================================ import gzip import io from typing import NoReturn from starlette.datastructures import Headers, MutableHeaders from starlette.types import ASGIApp, Message, Receive, Scope, Send DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",) class 
GZipMiddleware: def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None: self.app = app self.minimum_size = minimum_size self.compresslevel = compresslevel async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": # pragma: no cover await self.app(scope, receive, send) return headers = Headers(scope=scope) responder: ASGIApp if "gzip" in headers.get("Accept-Encoding", ""): responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel) else: responder = IdentityResponder(self.app, self.minimum_size) await responder(scope, receive, send) class IdentityResponder: content_encoding: str def __init__(self, app: ASGIApp, minimum_size: int) -> None: self.app = app self.minimum_size = minimum_size self.send: Send = unattached_send self.initial_message: Message = {} self.started = False self.content_encoding_set = False self.content_type_is_excluded = False async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.send = send await self.app(scope, receive, self.send_with_compression) async def send_with_compression(self, message: Message) -> None: message_type = message["type"] if message_type == "http.response.start": # Don't send the initial message until we've determined how to # modify the outgoing headers correctly. 
self.initial_message = message headers = Headers(raw=self.initial_message["headers"]) self.content_encoding_set = "content-encoding" in headers self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES) elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded): if not self.started: self.started = True await self.send(self.initial_message) await self.send(message) elif message_type == "http.response.body" and not self.started: self.started = True body = message.get("body", b"") more_body = message.get("more_body", False) if len(body) < self.minimum_size and not more_body: # Don't apply compression to small outgoing responses. await self.send(self.initial_message) await self.send(message) elif not more_body: # Standard response. body = self.apply_compression(body, more_body=False) headers = MutableHeaders(raw=self.initial_message["headers"]) headers.add_vary_header("Accept-Encoding") if body != message["body"]: headers["Content-Encoding"] = self.content_encoding headers["Content-Length"] = str(len(body)) message["body"] = body await self.send(self.initial_message) await self.send(message) else: # Initial body in streaming response. body = self.apply_compression(body, more_body=True) headers = MutableHeaders(raw=self.initial_message["headers"]) headers.add_vary_header("Accept-Encoding") if body != message["body"]: headers["Content-Encoding"] = self.content_encoding del headers["Content-Length"] message["body"] = body await self.send(self.initial_message) await self.send(message) elif message_type == "http.response.body": # Remaining body in streaming response. 
body = message.get("body", b"") more_body = message.get("more_body", False) message["body"] = self.apply_compression(body, more_body=more_body) await self.send(message) elif message_type == "http.response.pathsend": # pragma: no branch # Don't apply GZip to pathsend responses await self.send(self.initial_message) await self.send(message) def apply_compression(self, body: bytes, *, more_body: bool) -> bytes: """Apply compression on the response body. If more_body is False, any compression file should be closed. If it isn't, it won't be closed automatically until all background tasks complete. """ return body class GZipResponder(IdentityResponder): content_encoding = "gzip" def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None: super().__init__(app, minimum_size) self.gzip_buffer = io.BytesIO() self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: with self.gzip_buffer, self.gzip_file: await super().__call__(scope, receive, send) def apply_compression(self, body: bytes, *, more_body: bool) -> bytes: self.gzip_file.write(body) if not more_body: self.gzip_file.close() body = self.gzip_buffer.getvalue() self.gzip_buffer.seek(0) self.gzip_buffer.truncate() return body async def unattached_send(message: Message) -> NoReturn: raise RuntimeError("send awaitable not set") # pragma: no cover ================================================ FILE: starlette/middleware/httpsredirect.py ================================================ from starlette.datastructures import URL from starlette.responses import RedirectResponse from starlette.types import ASGIApp, Receive, Scope, Send class HTTPSRedirectMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"): url = 
URL(scope=scope) redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme] netloc = url.hostname if url.port in (80, 443) else url.netloc url = url.replace(scheme=redirect_scheme, netloc=netloc) response = RedirectResponse(url, status_code=307) await response(scope, receive, send) else: await self.app(scope, receive, send) ================================================ FILE: starlette/middleware/sessions.py ================================================ from __future__ import annotations import json import typing from base64 import b64decode, b64encode from typing import Literal import itsdangerous from itsdangerous.exc import BadSignature from starlette.datastructures import MutableHeaders, Secret from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Message, Receive, Scope, Send class SessionMiddleware: def __init__( self, app: ASGIApp, secret_key: str | Secret, session_cookie: str = "session", max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds path: str = "/", same_site: Literal["lax", "strict", "none"] = "lax", https_only: bool = False, domain: str | None = None, ) -> None: self.app = app self.signer = itsdangerous.TimestampSigner(str(secret_key)) self.session_cookie = session_cookie self.max_age = max_age self.path = path self.security_flags = "httponly; samesite=" + same_site if https_only: # Secure flag can be used with HTTPS only self.security_flags += "; secure" if domain is not None: self.security_flags += f"; domain={domain}" async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ("http", "websocket"): # pragma: no cover await self.app(scope, receive, send) return connection = HTTPConnection(scope) initial_session_was_empty = True if self.session_cookie in connection.cookies: data = connection.cookies[self.session_cookie].encode("utf-8") try: data = self.signer.unsign(data, max_age=self.max_age) scope["session"] = Session(json.loads(b64decode(data))) 
initial_session_was_empty = False except BadSignature: scope["session"] = Session() else: scope["session"] = Session() async def send_wrapper(message: Message) -> None: if message["type"] == "http.response.start": session: Session = scope["session"] headers = MutableHeaders(scope=message) if session.accessed: headers.add_vary_header("Cookie") if session.modified and session: # We have session data to persist. data = b64encode(json.dumps(session).encode("utf-8")) data = self.signer.sign(data) header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( session_cookie=self.session_cookie, data=data.decode("utf-8"), path=self.path, max_age=f"Max-Age={self.max_age}; " if self.max_age else "", security_flags=self.security_flags, ) headers.append("Set-Cookie", header_value) elif session.modified and not initial_session_was_empty: # The session has been cleared. header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( session_cookie=self.session_cookie, data="null", path=self.path, expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ", security_flags=self.security_flags, ) headers.append("Set-Cookie", header_value) await send(message) await self.app(scope, receive, send_wrapper) class Session(dict[str, typing.Any]): accessed: bool = False modified: bool = False def mark_accessed(self) -> None: self.accessed = True def mark_modified(self) -> None: self.accessed = True self.modified = True def __setitem__(self, key: str, value: typing.Any) -> None: self.mark_modified() super().__setitem__(key, value) def __delitem__(self, key: str) -> None: self.mark_modified() super().__delitem__(key) def clear(self) -> None: self.mark_modified() super().clear() def pop(self, key: str, *args: typing.Any) -> typing.Any: self.modified = self.modified or key in self return super().pop(key, *args) def setdefault(self, key: str, default: typing.Any = None) -> typing.Any: if key not in self: self.mark_modified() return 
super().setdefault(key, default) def update(self, *args: typing.Any, **kwargs: typing.Any) -> None: self.mark_modified() super().update(*args, **kwargs) ================================================ FILE: starlette/middleware/trustedhost.py ================================================ from __future__ import annotations from collections.abc import Sequence from starlette.datastructures import URL, Headers from starlette.responses import PlainTextResponse, RedirectResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." class TrustedHostMiddleware: def __init__( self, app: ASGIApp, allowed_hosts: Sequence[str] | None = None, www_redirect: bool = True, ) -> None: if allowed_hosts is None: allowed_hosts = ["*"] for pattern in allowed_hosts: assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD if pattern.startswith("*") and pattern != "*": assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD self.app = app self.allowed_hosts = list(allowed_hosts) self.allow_any = "*" in allowed_hosts self.www_redirect = www_redirect async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.allow_any or scope["type"] not in ( "http", "websocket", ): # pragma: no cover await self.app(scope, receive, send) return headers = Headers(scope=scope) host = headers.get("host", "").split(":")[0] is_valid_host = False found_www_redirect = False for pattern in self.allowed_hosts: if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])): is_valid_host = True break elif "www." + host == pattern: found_www_redirect = True if is_valid_host: await self.app(scope, receive, send) else: response: Response if found_www_redirect and self.www_redirect: url = URL(scope=scope) redirect_url = url.replace(netloc="www." 
+ url.netloc) response = RedirectResponse(url=str(redirect_url)) else: response = PlainTextResponse("Invalid host header", status_code=400) await response(scope, receive, send) ================================================ FILE: starlette/middleware/wsgi.py ================================================ from __future__ import annotations import io import math import sys import warnings from collections.abc import Callable, MutableMapping from typing import Any import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette.types import Receive, Scope, Send warnings.warn( "starlette.middleware.wsgi is deprecated and will be removed in a future release. " "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.", DeprecationWarning, stacklevel=2, ) def build_environ(scope: Scope, body: bytes) -> dict[str, Any]: """ Builds a scope and request body into a WSGI environ object. """ script_name = scope.get("root_path", "").encode("utf8").decode("latin1") path_info = scope["path"].encode("utf8").decode("latin1") if path_info.startswith(script_name): path_info = path_info[len(script_name) :] environ = { "REQUEST_METHOD": scope["method"], "SCRIPT_NAME": script_name, "PATH_INFO": path_info, "QUERY_STRING": scope["query_string"].decode("ascii"), "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", "wsgi.version": (1, 0), "wsgi.url_scheme": scope.get("scheme", "http"), "wsgi.input": io.BytesIO(body), "wsgi.errors": sys.stdout, "wsgi.multithread": True, "wsgi.multiprocess": True, "wsgi.run_once": False, } # Get server name and port - required in WSGI, not in ASGI server = scope.get("server") or ("localhost", 80) environ["SERVER_NAME"] = server[0] environ["SERVER_PORT"] = server[1] # Get client IP address if scope.get("client"): environ["REMOTE_ADDR"] = scope["client"][0] # Go through headers and make them into environ entries for name, value in scope.get("headers", []): name = name.decode("latin1") if name == "content-length": 
corrected_name = "CONTENT_LENGTH" elif name == "content-type": corrected_name = "CONTENT_TYPE" else: corrected_name = f"HTTP_{name}".upper().replace("-", "_") # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in # case value = value.decode("latin1") if corrected_name in environ: value = environ[corrected_name] + "," + value environ[corrected_name] = value return environ class WSGIMiddleware: def __init__(self, app: Callable[..., Any]) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] == "http" responder = WSGIResponder(self.app, scope) await responder(receive, send) class WSGIResponder: stream_send: ObjectSendStream[MutableMapping[str, Any]] stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] def __init__(self, app: Callable[..., Any], scope: Scope) -> None: self.app = app self.scope = scope self.status = None self.response_headers = None self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf) self.response_started = False self.exc_info: Any = None async def __call__(self, receive: Receive, send: Send) -> None: body = b"" more_body = True while more_body: message = await receive() body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) async with anyio.create_task_group() as task_group: task_group.start_soon(self.sender, send) async with self.stream_send: await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) if self.exc_info is not None: raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) async def sender(self, send: Send) -> None: async with self.stream_receive: async for message in self.stream_receive: await send(message) def start_response( self, status: str, response_headers: list[tuple[str, str]], exc_info: Any = None, ) -> None: self.exc_info = exc_info if not self.response_started: # pragma: no branch self.response_started = 
True status_code_string, _ = status.split(" ", 1) status_code = int(status_code_string) headers = [ (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] anyio.from_thread.run( self.stream_send.send, { "type": "http.response.start", "status": status_code, "headers": headers, }, ) def wsgi( self, environ: dict[str, Any], start_response: Callable[..., Any], ) -> None: for chunk in self.app(environ, start_response): anyio.from_thread.run( self.stream_send.send, {"type": "http.response.body", "body": chunk, "more_body": True}, ) anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""}) ================================================ FILE: starlette/py.typed ================================================ ================================================ FILE: starlette/requests.py ================================================ from __future__ import annotations import json import sys from collections.abc import AsyncGenerator, Iterator, Mapping from http import cookies as http_cookies from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast import anyio from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State from starlette.exceptions import HTTPException from starlette.formparsers import FormParser, MultiPartException, MultiPartParser from starlette.types import Message, Receive, Scope, Send if TYPE_CHECKING: from python_multipart.multipart import parse_options_header from starlette.applications import Starlette from starlette.middleware.sessions import Session from starlette.routing import Router else: try: try: from python_multipart.multipart import parse_options_header except ModuleNotFoundError: # pragma: no cover from multipart.multipart import parse_options_header except ModuleNotFoundError: # pragma: no cover parse_options_header = None if 
sys.version_info >= (3, 13):  # pragma: no cover
    from typing import TypeVar
else:  # pragma: no cover
    from typing_extensions import TypeVar

# Request headers copied onto server push promises by `Request.send_push_promise`.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}


def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.

    It attempts to mimic browser cookie parsing behavior: browsers and web servers
    frequently disregard the spec (RFC 6265) when setting and reading cookies,
    so we attempt to suit the common scenarios here.

    This function has been adapted from Django 3.1.0.
    Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
    on an outdated spec and will fail on lots of input we want to support.

    If a name repeats, the later value overwrites the earlier one.
    """
    cookie_dict: dict[str, str] = {}
    for chunk in cookie_string.split(";"):
        if "=" in chunk:
            key, val = chunk.split("=", 1)
        else:
            # Assume an empty name per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            key, val = "", chunk
        key, val = key.strip(), val.strip()
        if key or val:
            # unquote using Python's algorithm.
            cookie_dict[key] = http_cookies._unquote(val)
    return cookie_dict


class ClientDisconnect(Exception):
    """Raised when the client goes away, e.g. while the request body is being streamed."""

    pass


# Type of `HTTPConnection.state`; defaults to `State` when not parametrized.
StateT = TypeVar("StateT", bound=Mapping[str, Any] | State, default=State)


class HTTPConnection(Mapping[str, Any], Generic[StateT]):
    """
    A base class for incoming HTTP connections, that is used to provide
    any functionality that is common to both `Request` and `WebSocket`.

    The Mapping interface reads straight through to the raw ASGI scope. Derived
    values (url, headers, cookies, ...) are computed lazily and cached on the instance.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Don't use the `abc.Mapping.__eq__` implementation.
    # Connection instances should never be considered equal
    # unless `self is other`.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        return self.scope["app"]

    @property
    def url(self) -> URL:
        if not hasattr(self, "_url"):  # pragma: no branch
            self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """The application's root URL: no query string, and the path always ends in "/"."""
        if not hasattr(self, "_base_url"):
            base_url_scope = dict(self.scope)
            # This is used by request.url_for, it might be used inside a Mount which
            # would have its own child scope with its own root_path, but the base URL
            # for url_for should still be the top level app root path.
            app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
            path = app_root_path
            if not path.endswith("/"):
                path += "/"
            base_url_scope["path"] = path
            base_url_scope["query_string"] = b""
            base_url_scope["root_path"] = app_root_path
            self._base_url = URL(scope=base_url_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        if not hasattr(self, "_headers"):
            self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        if not hasattr(self, "_query_params"):  # pragma: no branch
            self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        if not hasattr(self, "_cookies"):
            cookies: dict[str, str] = {}
            cookie_headers = self.headers.getlist("cookie")

            # Multiple Cookie headers are merged; later headers win on name clashes.
            for header in cookie_headers:
                cookies.update(cookie_parser(header))

            self._cookies = cookies
        return self._cookies

    @property
    def client(self) -> Address | None:
        # client is a 2 item tuple of (host, port), None if missing
        host_port = self.scope.get("client")
        if host_port is not None:
            return Address(*host_port)
        return None

    @property
    def session(self) -> dict[str, Any]:
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        session: Session = self.scope["session"]
        # We keep the hasattr in case people actually use their own `SessionMiddleware` implementation.
        if hasattr(session, "mark_accessed"):  # pragma: no branch
            session.mark_accessed()
        return session

    @property
    def auth(self) -> Any:
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> Any:
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> StateT:
        if not hasattr(self, "_state"):
            # Ensure 'state' has an empty dict if it's not already populated.
            self.scope.setdefault("state", {})
            # Create a state instance with a reference to the dict in which it should
            # store info
            self._state = State(self.scope["state"])
        return cast(StateT, self._state)

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """Build an absolute URL for the named route, resolved via the scope's router (or app)."""
        url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if url_path_provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        url_path = url_path_provider.url_path_for(name, **path_params)
        return url_path.make_absolute_url(base_url=self.base_url)


# Default channels for a Request built without receive/send; any use is an error.
async def empty_receive() -> NoReturn:
    raise RuntimeError("Receive channel has not been made available")


async def empty_send(message: Message) -> NoReturn:
    raise RuntimeError("Send channel has not been made available")


class Request(HTTPConnection[StateT]):
    """An HTTP request, with helpers to read the body as a stream, bytes, JSON or form data."""

    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield body chunks as they arrive, finishing with an empty chunk.

        Raises RuntimeError if the stream was already consumed without being
        cached by `body()`, and ClientDisconnect on an `http.disconnect` message.
        """
        # A body already cached by `body()` is replayed instead of re-read.
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if not message.get("more_body", False):
                    self._stream_consumed = True
                if body:
                    yield body
            elif message["type"] == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Read and cache the full request body."""
        if not hasattr(self, "_body"):
            chunks: list[bytes] = []
            async for chunk in self.stream():
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

    async def json(self) -> Any:
        """Parse and cache the body as JSON."""
        if not hasattr(self, "_json"):  # pragma: no branch
            body = await self.body()
            self._json = json.loads(body)
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        # Parse once according to Content-Type; unknown types yield an empty FormData.
        if self._form is None:  # pragma: no branch
            assert parse_options_header is not None, (
                "The `python-multipart` library must be installed to use form parsing."
            )
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(),
                        max_files=max_files,
                        max_fields=max_fields,
                        max_part_size=max_part_size,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException as exc:
                    # Within an application, surface parse errors as HTTP 400.
                    if "app" in self.scope:
                        raise HTTPException(status_code=400, detail=exc.message)
                    raise exc
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        # Presumably usable as `await request.form()` or `async with request.form()`,
        # per the wrapper's type name — confirm against starlette._utils.
        return AwaitableOrContextManagerWrapper(
            self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
        )

    async def close(self) -> None:
        # Close the parsed form, if any was parsed.
        if self._form is not None:  # pragma: no branch
            await self._form.close()

    async def is_disconnected(self) -> bool:
        """Poll, without blocking, whether the client has disconnected."""
        if not self._is_disconnected:
            message: Message = {}

            # If message isn't immediately available, move on
            # (the scope is cancelled before the await, making it a non-blocking poll).
            with anyio.CancelScope() as cs:
                cs.cancel()
                message = await self._receive()

            if message.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        # No-op unless the server advertises the "http.response.push" extension.
        if "http.response.push" in self.scope.get("extensions", {}):
            raw_headers: list[tuple[bytes, bytes]] = []
            for name in SERVER_PUSH_HEADERS_TO_COPY:
                for value in self.headers.getlist(name):
                    raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
            await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})

================================================
FILE: starlette/responses.py
================================================
from __future__ import annotations

import hashlib
import
http.cookies
import json
import os
import stat
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence
from datetime import datetime
from email.utils import format_datetime, formatdate
from functools import partial
from mimetypes import guess_type
from secrets import token_hex
from typing import Any, Literal
from urllib.parse import quote

import anyio
import anyio.to_thread

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.requests import ClientDisconnect
from starlette.types import Message, Receive, Scope, Send


class Response:
    """
    A complete, in-memory HTTP response sent as one start message and one body message.

    An optional background task is awaited after the body has been sent.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: Any = None,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: Any) -> bytes | memoryview:
        # None -> empty body; bytes-like passes through; anything else is encoded with `charset`.
        if content is None:
            return b""
        if isinstance(content, bytes | memoryview):
            return content
        return content.encode(self.charset)  # type: ignore

    def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
        """Build `raw_headers`, adding Content-Length/Content-Type unless the caller set them."""
        if headers is None:
            raw_headers: list[tuple[bytes, bytes]] = []
            populate_content_length = True
            populate_content_type = True
        else:
            raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
            keys = [h[0] for h in raw_headers]
            populate_content_length = b"content-length" not in keys
            populate_content_type = b"content-type" not in keys

        # Subclasses that never set `body` (streaming/file responses) get no length here.
        body = getattr(self, "body", None)
        if (
            body is not None
            and populate_content_length
            and not (self.status_code < 200 or self.status_code in (204, 304))
        ):
            content_length = str(len(body))
            raw_headers.append((b"content-length", content_length.encode("latin-1")))

        content_type = self.media_type
        if content_type is not None and populate_content_type:
            # Text types get an explicit charset unless one is already given.
            if content_type.startswith("text/") and "charset=" not in content_type.lower():
                content_type += "; charset=" + self.charset
            raw_headers.append((b"content-type", content_type.encode("latin-1")))

        self.raw_headers = raw_headers

    @property
    def headers(self) -> MutableHeaders:
        if not hasattr(self, "_headers"):
            self._headers = MutableHeaders(raw=self.raw_headers)
        return self._headers

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: int | None = None,
        expires: datetime | str | int | None = None,
        path: str | None = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Literal["lax", "strict", "none"] | None = "lax",
        partitioned: bool = False,
    ) -> None:
        """Append one Set-Cookie header; each call adds a separate header."""
        cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
        cookie[key] = value
        if max_age is not None:
            cookie[key]["max-age"] = max_age
        if expires is not None:
            if isinstance(expires, datetime):
                cookie[key]["expires"] = format_datetime(expires, usegmt=True)
            else:
                cookie[key]["expires"] = expires
        if path is not None:
            cookie[key]["path"] = path
        if domain is not None:
            cookie[key]["domain"] = domain
        if secure:
            cookie[key]["secure"] = True
        if httponly:
            cookie[key]["httponly"] = True
        if samesite is not None:
            assert samesite.lower() in [
                "strict",
                "lax",
                "none",
            ], "samesite must be either 'strict', 'lax' or 'none'"
            cookie[key]["samesite"] = samesite
        if partitioned:
            if sys.version_info < (3, 14):
                raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.")  # pragma: no cover
            cookie[key]["partitioned"] = True  # pragma: no cover

        cookie_val = cookie.output(header="").strip()
        self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Literal["lax", "strict", "none"] | None = "lax",
    ) -> None:
        # Re-set the cookie with max_age=0 and expires=0 so the client discards it.
        self.set_cookie(
            key,
            max_age=0,
            expires=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )

    def _wrap_websocket_denial_send(self, send: Send) -> Send:
        # For websocket scopes, rename http.response.* messages to websocket.http.response.*
        # so the HTTP response is sent as a websocket denial response.
        async def wrapped(message: Message) -> None:
            message_type = message["type"]
            if message_type in {"http.response.start", "http.response.body"}:  # pragma: no branch
                message = {**message, "type": "websocket." + message_type}
            await send(message)

        return wrapped

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "websocket":
            send = self._wrap_websocket_denial_send(send)
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
        await send({"type": "http.response.body", "body": self.body})

        if self.background is not None:
            await self.background()


class HTMLResponse(Response):
    media_type = "text/html"


class PlainTextResponse(Response):
    media_type = "text/plain"


class JSONResponse(Response):
    """Response whose content is serialized to compact UTF-8 JSON."""

    media_type = "application/json"

    def __init__(
        self,
        content: Any,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(content, status_code, headers, media_type, background)

    def render(self, content: Any) -> bytes:
        # allow_nan=False: NaN/Infinity raise ValueError instead of emitting non-standard JSON.
        return json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        ).encode("utf-8")


class RedirectResponse(Response):
    def __init__(
        self,
        url: str | URL,
        status_code: int = 307,
        headers: Mapping[str, str] | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
        # Percent-encode unsafe characters, keeping URL delimiters and existing "%" escapes.
        self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")


Content = str | bytes | memoryview
SyncContentStream = Iterable[Content]
AsyncContentStream = AsyncIterable[Content]
ContentStream = AsyncContentStream | SyncContentStream


class StreamingResponse(Response):
    """
    Response whose body comes from a sync or async iterable of str/bytes chunks.

    Sync iterables are consumed in a threadpool; str chunks are encoded with `charset`.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        if isinstance(content, AsyncIterable):
            self.body_iterator = content
        else:
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                break

    async def stream_response(self, send: Send) -> None:
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
        async for chunk in self.body_iterator:
            if not isinstance(chunk, bytes | memoryview):
                chunk = chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": chunk, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "websocket":
            send = self._wrap_websocket_denial_send(send)
            await self.stream_response(send)
            if self.background is not None:
                await self.background()
            return

        spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))

        if spec_version >= (2, 4):
            # ASGI >= 2.4: a send to a disconnected client raises OSError, so no
            # separate disconnect listener is needed.
            try:
                await self.stream_response(send)
            except OSError:
                raise ClientDisconnect()
        else:
            # Older servers: race streaming against a disconnect listener; whichever
            # finishes first cancels the other.
            with collapse_excgroups():
                async with anyio.create_task_group() as task_group:

                    async def wrap(func: Callable[[], Awaitable[None]]) -> None:
                        await func()
                        task_group.cancel_scope.cancel()

                    task_group.start_soon(wrap, partial(self.stream_response, send))
                    await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()


class MalformedRangeHeader(Exception):
    def __init__(self,
content: str = "Malformed range header.") -> None: self.content = content class RangeNotSatisfiable(Exception): def __init__(self, max_size: int) -> None: self.max_size = max_size class FileResponse(Response): chunk_size = 64 * 1024 def __init__( self, path: str | os.PathLike[str], status_code: int = 200, headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, filename: str | None = None, stat_result: os.stat_result | None = None, content_disposition_type: str = "attachment", ) -> None: self.path = path self.status_code = status_code self.filename = filename if media_type is None: media_type = guess_type(filename or path)[0] or "text/plain" self.media_type = media_type self.background = background self.init_headers(headers) self.headers.setdefault("accept-ranges", "bytes") if self.filename is not None: content_disposition_filename = quote(self.filename) if content_disposition_filename != self.filename: content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" else: content_disposition = f'{content_disposition_type}; filename="{self.filename}"' self.headers.setdefault("content-disposition", content_disposition) self.stat_result = stat_result if stat_result is not None: self.set_stat_headers(stat_result) def set_stat_headers(self, stat_result: os.stat_result) -> None: content_length = str(stat_result.st_size) last_modified = formatdate(stat_result.st_mtime, usegmt=True) etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size) etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"' self.headers.setdefault("content-length", content_length) self.headers.setdefault("last-modified", last_modified) self.headers.setdefault("etag", etag) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope_type = scope["type"] send_header_only = scope_type == "http" and scope["method"].upper() == "HEAD" send_pathsend = 
scope_type == "http" and "http.response.pathsend" in scope.get("extensions", {}) if scope_type == "websocket": send = self._wrap_websocket_denial_send(send) if self.stat_result is None: try: stat_result = await anyio.to_thread.run_sync(os.stat, self.path) self.set_stat_headers(stat_result) except FileNotFoundError: raise RuntimeError(f"File at path {self.path} does not exist.") else: mode = stat_result.st_mode if not stat.S_ISREG(mode): raise RuntimeError(f"File at path {self.path} is not a file.") else: stat_result = self.stat_result headers = Headers(scope=scope) http_range = headers.get("range") http_if_range = headers.get("if-range") if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)): await self._handle_simple(send, send_header_only, send_pathsend) else: try: ranges = self._parse_range_header(http_range, stat_result.st_size) except MalformedRangeHeader as exc: return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send) except RangeNotSatisfiable as exc: response = PlainTextResponse(status_code=416, headers={"Content-Range": f"bytes */{exc.max_size}"}) return await response(scope, receive, send) if len(ranges) == 1: start, end = ranges[0] await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only) else: await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only) if self.background is not None: await self.background() async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None: await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) elif send_pathsend: await send({"type": "http.response.pathsend", "path": str(self.path)}) else: async with await anyio.open_file(self.path, mode="rb") as file: more_body = True while more_body: chunk = await 
file.read(self.chunk_size) more_body = len(chunk) == self.chunk_size await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) async def _handle_single_range( self, send: Send, start: int, end: int, file_size: int, send_header_only: bool ) -> None: headers = MutableHeaders(raw=list(self.raw_headers)) headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}" headers["content-length"] = str(end - start) await send({"type": "http.response.start", "status": 206, "headers": headers.raw}) if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: async with await anyio.open_file(self.path, mode="rb") as file: await file.seek(start) more_body = True while more_body: chunk = await file.read(min(self.chunk_size, end - start)) start += len(chunk) more_body = len(chunk) == self.chunk_size and start < end await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) async def _handle_multiple_ranges( self, send: Send, ranges: list[tuple[int, int]], file_size: int, send_header_only: bool, ) -> None: # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes). 
boundary = token_hex(13) content_length, header_generator = self.generate_multipart( ranges, boundary, file_size, self.headers["content-type"] ) headers = MutableHeaders(raw=list(self.raw_headers)) headers["content-type"] = f"multipart/byteranges; boundary={boundary}" headers["content-length"] = str(content_length) await send({"type": "http.response.start", "status": 206, "headers": headers.raw}) if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: async with await anyio.open_file(self.path, mode="rb") as file: for start, end in ranges: await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True}) await file.seek(start) while start < end: chunk = await file.read(min(self.chunk_size, end - start)) start += len(chunk) await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"\r\n", "more_body": True}) await send( { "type": "http.response.body", "body": f"--{boundary}--".encode("latin-1"), "more_body": False, } ) def _should_use_range(self, http_if_range: str) -> bool: return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"] @classmethod def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]: ranges: list[tuple[int, int]] = [] try: units, range_ = http_range.split("=", 1) except ValueError: raise MalformedRangeHeader() units = units.strip().lower() if units != "bytes": raise MalformedRangeHeader("Only support bytes range") ranges = cls._parse_ranges(range_, file_size) if len(ranges) == 0: raise MalformedRangeHeader("Range header: range must be requested") if any(not (0 <= start < file_size) for start, _ in ranges): raise RangeNotSatisfiable(file_size) if any(start > end for start, end in ranges): raise MalformedRangeHeader("Range header: start must be less than end") if len(ranges) == 1: return ranges # Merge overlapping ranges ranges.sort() 
result: list[tuple[int, int]] = [ranges[0]] for start, end in ranges[1:]: last_start, last_end = result[-1] if start <= last_end: result[-1] = (last_start, max(last_end, end)) else: result.append((start, end)) return result @classmethod def _parse_ranges(cls, range_: str, file_size: int) -> list[tuple[int, int]]: ranges: list[tuple[int, int]] = [] for part in range_.split(","): part = part.strip() # If the range is empty or a single dash, we ignore it. if not part or part == "-": continue # If the range is not in the format "start-end", we ignore it. if "-" not in part: continue start_str, end_str = part.split("-", 1) start_str = start_str.strip() end_str = end_str.strip() try: start = int(start_str) if start_str else file_size - int(end_str) end = int(end_str) + 1 if start_str and end_str and int(end_str) < file_size else file_size ranges.append((start, end)) except ValueError: # If the range is not numeric, we ignore it. continue return ranges def generate_multipart( self, ranges: Sequence[tuple[int, int]], boundary: str, max_size: int, content_type: str, ) -> tuple[int, Callable[[int, int], bytes]]: r""" Multipart response headers generator. 
```
        --{boundary}\r\n
        Content-Type: {content_type}\r\n
        Content-Range: bytes {start}-{end-1}/{max_size}\r\n
        \r\n
        ..........content...........\r\n
        --{boundary}\r\n
        Content-Type: {content_type}\r\n
        Content-Range: bytes {start}-{end-1}/{max_size}\r\n
        \r\n
        ..........content...........\r\n
        --{boundary}--
        ```
        """
        boundary_len = len(boundary)
        # Fixed bytes per part, excluding the boundary, content type, total size
        # and the start/end numbers:
        #   "--" + CRLF                             ->  4
        #   "Content-Type: " + CRLF                 -> 16
        #   "Content-Range: bytes " + "-" + "/" + CRLF -> 25
        #   blank-line CRLF                         ->  2
        #   CRLF sent after the part's content      ->  2
        # Total: 49. Character counts equal byte counts because the header
        # text is latin-1 encoded (one byte per character).
        static_header_part_len = 49 + boundary_len + len(content_type) + len(str(max_size))
        content_length = sum(
            (len(str(start)) + len(str(end - 1)) + static_header_part_len)  # Headers
            + (end - start)  # Content
            for start, end in ranges
        ) + (
            4 + boundary_len  # --boundary--
        )
        # The callable renders the header block for one `(start, end)` part;
        # `end` is exclusive, hence `end - 1` in Content-Range.
        return (
            content_length,
            lambda start, end: (
                f"""\
--{boundary}\r
Content-Type: {content_type}\r
Content-Range: bytes {start}-{end - 1}/{max_size}\r
\r
"""
            ).encode("latin-1"),
        )

================================================
FILE: starlette/routing.py
================================================
from __future__ import annotations

import contextlib
import functools
import inspect
import re
import traceback
import types
import warnings
from collections.abc import Awaitable, Callable, Collection, Generator, Sequence
from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
from enum import Enum
from re import Pattern
from typing import Any, TypeVar

from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import get_route_path, is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


class NoMatchFound(Exception):
    """
    Raised by
`.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
    if no matching route exists.
    """

    def __init__(self, name: str, path_params: dict[str, Any]) -> None:
        # Only the parameter names (not their values) appear in the message.
        params = ", ".join(list(path_params.keys()))
        super().__init__(f'No route exists for name "{name}" and params "{params}".')


class Match(Enum):
    """
    Outcome of matching a route against an ASGI scope.
    """

    # The route does not apply to this scope.
    NONE = 0
    # The route applies only in part; `Route.matches` returns this when the
    # path matches but the HTTP method is not allowed.
    PARTIAL = 1
    # The route can handle this scope.
    FULL = 2


def request_response(
    func: Callable[[Request], Awaitable[Response] | Response],
) -> ASGIApp:
    """
    Takes a function or coroutine `func(request) -> response`,
    and returns an ASGI application.

    Synchronous functions are called through `run_in_threadpool`.
    """
    f: Callable[[Request], Awaitable[Response]] = (
        func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
    )

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive, send)

        # Inner app: call the endpoint and send its response. It is handed to
        # `wrap_app_handling_exceptions` together with the request, presumably
        # so errors raised by the endpoint reach the configured handlers.
        async def app(scope: Scope, receive: Receive, send: Send) -> None:
            response = await f(request)
            await response(scope, receive, send)

        await wrap_app_handling_exceptions(app, request)(scope, receive, send)

    return app


def websocket_session(
    func: Callable[[WebSocket], Awaitable[None]],
) -> ASGIApp:
    """
    Takes a coroutine `func(session)`, and returns an ASGI application.
"""
    # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        session = WebSocket(scope, receive=receive, send=send)

        # Inner app runs the endpoint; it is handed to
        # `wrap_app_handling_exceptions` together with the session.
        async def app(scope: Scope, receive: Receive, send: Send) -> None:
            await func(session)

        await wrap_app_handling_exceptions(app, session)(scope, receive, send)

    return app


def get_name(endpoint: Callable[..., Any]) -> str:
    """
    Return `endpoint.__name__`, falling back to the name of its class for
    objects that have no `__name__` (e.g. callable instances).
    """
    return getattr(endpoint, "__name__", endpoint.__class__.__name__)


def replace_params(
    path: str,
    param_convertors: dict[str, Convertor[Any]],
    path_params: dict[str, str],
) -> tuple[str, dict[str, str]]:
    """
    Substitute each `{key}` placeholder in `path` with the convertor's string
    form of the matching value from `path_params`.

    Keys that were substituted are popped from `path_params` in place. The
    returned dict is that same object, holding only the params that had no
    placeholder in `path`.
    """
    # Iterate over a snapshot because keys are popped from `path_params` below.
    for key, value in list(path_params.items()):
        if "{" + key + "}" in path:
            convertor = param_convertors[key]
            value = convertor.to_string(value)
            path = path.replace("{" + key + "}", value)
            path_params.pop(key)
    return path, path_params


# Match parameters in URL paths, eg. '{param}', and '{param:int}'
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")


def compile_path(
    path: str,
) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]:
    """
    Given a path string, like: "/{username:str}",
    or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
    of (regex, format, {param_name:convertor}).
regex:      "/(?P<username>[^/]+)"
    format:     "/{username}"
    convertors: {"username": StringConvertor()}

    Strings that do not start with "/" are treated as host patterns: any
    ":port" suffix after the last parameter is left out of the regex.

    Raises `ValueError` if a parameter name appears more than once.
    """
    is_host = not path.startswith("/")

    path_regex = "^"
    path_format = ""
    duplicated_params: set[str] = set()

    idx = 0
    param_convertors = {}
    for match in PARAM_REGEX.finditer(path):
        # A parameter without an explicit ":convertor" gets the default "str".
        param_name, convertor_type = match.groups("str")
        convertor_type = convertor_type.lstrip(":")
        assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
        convertor = CONVERTOR_TYPES[convertor_type]

        # Literal text before the parameter is escaped; the parameter becomes a
        # named group using the convertor's regex.
        path_regex += re.escape(path[idx : match.start()])
        path_regex += f"(?P<{param_name}>{convertor.regex})"

        path_format += path[idx : match.start()]
        path_format += "{%s}" % param_name

        if param_name in param_convertors:
            duplicated_params.add(param_name)

        param_convertors[param_name] = convertor

        idx = match.end()

    if duplicated_params:
        names = ", ".join(sorted(duplicated_params))
        ending = "s" if len(duplicated_params) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    if is_host:
        # Align with `Host.matches()` behavior, which ignores port.
        hostname = path[idx:].split(":")[0]
        path_regex += re.escape(hostname) + "$"
    else:
        path_regex += re.escape(path[idx:]) + "$"

    path_format += path[idx:]

    return re.compile(path_regex), path_format, param_convertors


class BaseRoute:
    """
    Interface shared by `Route`, `WebSocketRoute`, `Mount` and `Host`.
    Subclasses implement `matches`, `url_path_for` and `handle`.
    """

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        A route may be used in isolation as a stand-alone ASGI app.
        This is a somewhat contrived case, as they'll almost always be used
        within a Router, but could be useful for some tooling and minimal apps.
"""
        match, child_scope = self.matches(scope)
        if match == Match.NONE:
            if scope["type"] == "http":
                response = PlainTextResponse("Not Found", status_code=404)
                await response(scope, receive, send)
            elif scope["type"] == "websocket":  # pragma: no branch
                websocket_close = WebSocketClose()
                await websocket_close(scope, receive, send)
            return

        # PARTIAL matches are passed to `handle` like FULL ones; `handle` is
        # responsible for rejecting them (e.g. `Route.handle` answers 405).
        scope.update(child_scope)
        await self.handle(scope, receive, send)


class Route(BaseRoute):
    """
    An HTTP route: matches the request path against `path` and dispatches the
    request to `endpoint`.
    """

    def __init__(
        self,
        path: str,
        endpoint: Callable[..., Any],
        *,
        methods: Collection[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.include_in_schema = include_in_schema

        # Unwrap functools.partial layers so the function/method check below
        # looks at the underlying callable.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(request) -> response`.
            self.app = request_response(endpoint)
            if methods is None:
                methods = ["GET"]
        else:
            # Endpoint is a class. Treat it as ASGI.
self.app = endpoint

        # Each middleware wraps the app built so far; iterating in reverse makes
        # the first entry of `middleware` the outermost layer.
        if middleware is not None:
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(self.app, *args, **kwargs)

        # `self.methods = None` means any method is accepted (see `matches` and
        # `handle`, which only check methods when the set is non-empty).
        if methods is None:
            self.methods = None
        else:
            self.methods = {method.upper() for method in methods}
            # A route that serves GET also serves HEAD.
            if "GET" in self.methods:
                self.methods.add("HEAD")

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Match an HTTP scope against this route's path.

        Returns `Match.FULL` when the path matches and the method is allowed,
        `Match.PARTIAL` when only the path matches, and `Match.NONE` otherwise
        (including for non-HTTP scopes).
        """
        path_params: dict[str, Any]
        if scope["type"] == "http":
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # Converted values are layered over any path params already in
                # the scope (e.g. ones set by an enclosing Mount or Host).
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                if self.methods and scope["method"] not in self.methods:
                    return Match.PARTIAL, child_scope
                else:
                    return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Build this route's URL path. `name` must equal the route name and
        `path_params` must supply exactly the route's parameters; otherwise
        `NoMatchFound` is raised.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        if self.methods and scope["method"] not in self.methods:
            # NOTE(review): `self.methods` is a set, so the order of methods in
            # the Allow header is not guaranteed to be stable across processes.
            headers = {"Allow": ", ".join(self.methods)}
            # Inside an application ("app" in scope) raise and let the exception
            # handling build the 405; standalone, respond directly.
            if "app" in scope:
                raise HTTPException(status_code=405, headers=headers)
            else:
                response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
            await response(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, Route)
            and self.path == other.path
            and self.endpoint == other.endpoint
            and self.methods == other.methods
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        methods = sorted(self.methods or [])
        path, name = self.path, self.name
        return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"


class WebSocketRoute(BaseRoute):
    """
    A WebSocket route: matches websocket scopes against `path` and dispatches
    them to `endpoint`.
    """

    def __init__(
        self,
        path: str,
        endpoint: Callable[..., Any],
        *,
        name: str | None = None,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name

        # Unwrap functools.partial layers before deciding how to wrap the endpoint.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(websocket)`.
            self.app = websocket_session(endpoint)
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        # The first middleware entry ends up as the outermost layer.
        if middleware is not None:
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(self.app, *args, **kwargs)

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Return `Match.FULL` for a websocket scope whose path matches, else
        `Match.NONE`.
        """
        path_params: dict[str, Any]
        if scope["type"] == "websocket":
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Build this route's URL path (protocol "websocket"). `name` and the set of
        `path_params` must match exactly; otherwise `NoMatchFound` is raised.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
        assert not remaining_params
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"


class Mount(BaseRoute):
    """
    Delegate every path under `path` to a sub-application: either `app`, or a
    `Router` built from `routes`.
    """

    def __init__(
        self,
        path: str,
        app: ASGIApp | None = None,
        routes: Sequence[BaseRoute] | None = None,
        name: str | None = None,
        *,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
        # Trailing slashes are dropped; "/{path:path}" is appended below to
        # capture the remainder of the request path.
        self.path = path.rstrip("/")
        if app is not None:
            self._base_app: ASGIApp = app
        else:
            self._base_app = Router(routes=routes)
        self.app = self._base_app
        if middleware is not None:
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(self.app, *args, **kwargs)
        self.name = name
        self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")

    @property
    def routes(self) -> list[BaseRoute]:
        # Child routes of the mounted (un-wrapped) app, or [] if it has none.
        return getattr(self._base_app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        path_params: dict[str, Any]
        if scope["type"] in ("http", "websocket"):  # pragma: no branch
            root_path = scope.get("root_path", "")
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # `matched_path` is the prefix consumed by this mount; it is
                # appended to the child scope's root_path below.
                remaining_path = "/" + matched_params.pop("path")
                matched_path = route_path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {
                    "path_params": path_params,
                    # app_root_path will only be set at the top level scope,
                    # initialized with the (optional) value of a root_path
                    # set above/before Starlette. And even though any
                    # mount will have its own child scope with its own respective
                    # root_path, the app_root_path will always be available in all
                    # the child scopes with the same top level value because it's
                    # set only once here with a default, any other child scope will
                    # just inherit that app_root_path default value stored in the
                    # scope. All this is needed to support Request.url_for(), as it
                    # uses the app_root_path to build the URL path.
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Resolve `name` to this mount itself (`name == self.name` with a `path`
        param) or, for "mount_name:child_name" names (or any name when the mount
        is unnamed), to a child route prefixed with this mount's path.
        """
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
            # Build this mount's own prefix with an empty `path`, then let the
            # child routes build the remainder; a caller-supplied `path` is
            # forwarded to the children instead.
            path_kwarg = path_params.get("path")
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
            if path_kwarg is not None:
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Mount) and self.path == other.path and self.app == other.app

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"


class Host(BaseRoute):
    """
    Route on the request's `Host` header (port ignored) rather than its path,
    dispatching matching scopes to `app`.
    """

    def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
        assert not host.startswith("/"), "Host must not start with '/'"
        self.host = host
        self.app = app
        self.name = name
        self.host_regex, self.host_format, self.param_convertors = compile_path(host)

    @property
    def routes(self) -> list[BaseRoute]:
        # Child routes of the wrapped app, or [] if it exposes none.
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        if scope["type"] in ("http", "websocket"):  # pragma:no branch
            headers = Headers(scope=scope)
            # Drop any ":port" suffix before matching, as compile_path does for
            # host patterns.
            host = headers.get("host", "").split(":")[0]
            match = self.host_regex.match(host)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"path_params": path_params, "endpoint": self.app}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Resolve `name` to this host itself (`name == self.name` with a `path`
        param) or to a child route, returning a URLPath that carries the host.
        """
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path = path_params.pop("path")
            host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
            if not remaining_params:
                return URLPath(path=path, host=host)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=str(url), protocol=url.protocol, host=host)
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Host) and self.host == other.host and self.app == other.app

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"


_T = TypeVar("_T")


class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
    """
    Expose a synchronous context manager through the async protocol by calling
    its `__enter__` / `__exit__` directly (no thread offloading).
    """

    def __init__(self, cm: AbstractContextManager[_T]):
        self._cm = cm

    async def __aenter__(self) -> _T:
        return self._cm.__enter__()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        return self._cm.__exit__(exc_type, exc_value, traceback)


def _wrap_gen_lifespan_context(
    lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
    """
    Adapt a (deprecated) generator-function lifespan into a callable returning
    an async context manager.
    """
    cmgr = contextlib.contextmanager(lifespan_context)

    @functools.wraps(cmgr)
    def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
        return _AsyncLiftContextManager(cmgr(app))

    return wrapper


class _DefaultLifespan:
    """
    No-op lifespan used when `Router` is given none. Calling an instance returns
    the instance itself, so `lifespan_context(app)` is an async context manager
    that does nothing.
    """

    def __init__(self, router: Router):
        self._router = router

    async def __aenter__(self) -> None:
        pass

    async def __aexit__(self, *exc_info: object) -> None:
        pass

    def __call__(self: _T, app: object) -> _T:
        return self


class Router:
    """
    ASGI app that dispatches to the first matching route in `routes`, runs the
    lifespan protocol, and falls back to `default` (a 404 by default).
    """

    def __init__(
        self,
        routes: Sequence[BaseRoute] | None = None,
        redirect_slashes: bool = True,
        default: ASGIApp | None = None,
        # the generic to Lifespan[AppType] is the type of the top level application
        # which the router cannot know statically, so we use Any
        lifespan: Lifespan[Any] | None = None,
        *,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default

        # Normalize `lifespan` into `self.lifespan_context`: a callable taking
        # the app and returning an async context manager.
        if lifespan is None:
            self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(lifespan)
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(lifespan)
        else:
            self.lifespan_context = lifespan

        # `self.app` (the routing logic) is the innermost layer; the first
        # middleware entry becomes the outermost.
        self.middleware_stack = self.app
        if middleware:
            for cls, args, kwargs in reversed(middleware):
                self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
        await response(scope, receive, send)

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Return the URL path built by the first route that knows `name`;
        raise `NoMatchFound` if none does.
        """
        for route in self.routes:
            try:
                return route.url_path_for(name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound(name, path_params)

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.
        """
        started = False
        app: Any = scope.get("app")
        # First message (presumably "lifespan.startup"); its content is not inspected.
        await receive()
        try:
            async with self.lifespan_context(app) as maybe_state:
                if maybe_state is not None:
                    if "state" not in scope:
                        raise RuntimeError('The server does not support "state" in the lifespan scope.')
                    scope["state"].update(maybe_state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                # Wait for the next message (presumably "lifespan.shutdown")
                # before leaving the lifespan context.
                await receive()
        except BaseException:
            # Report the failure for whichever phase was running, then re-raise.
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        await self.middleware_stack(scope, receive, send)

    async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Routing logic (innermost layer of the middleware stack).
        """
        assert scope["type"] in ("http", "websocket", "lifespan")

        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                # Remember only the first partial match; `partial_scope` is
                # assigned exactly when `partial` is.
                partial = route
                partial_scope = child_scope

        if partial is not None:
            # Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        # No route matched. For HTTP, if toggling the trailing slash would make
        # some route match, redirect to that URL instead of returning 404.
        route_path = get_route_path(scope)
        if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
            redirect_scope = dict(scope)
            if route_path.endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Router) and self.routes == other.routes

    def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: no cover
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: no cover
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: Callable[[Request], Awaitable[Response] | Response],
        methods: Collection[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: no cover
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self,
        path: str,
        endpoint: Callable[[WebSocket], Awaitable[None]],
        name: str | None = None,
    ) -> None:  #
pragma: no cover route = WebSocketRoute(path, endpoint=endpoint, name=name) self.routes.append(route) ================================================ FILE: starlette/schemas.py ================================================ from __future__ import annotations import inspect import re from collections.abc import Callable from typing import Any, NamedTuple from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Host, Mount, Route try: import yaml except ModuleNotFoundError: # pragma: no cover yaml = None # type: ignore[assignment] class OpenAPIResponse(Response): media_type = "application/vnd.oai.openapi" def render(self, content: Any) -> bytes: assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse." assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary." return yaml.dump(content, default_flow_style=False).encode("utf-8") class EndpointInfo(NamedTuple): path: str http_method: str func: Callable[..., Any] _remove_converter_pattern = re.compile(r":\w+}") class BaseSchemaGenerator: def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]: raise NotImplementedError() # pragma: no cover def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: """ Given the routes, yields the following information: - path eg: /users/ - http_method one of 'get', 'post', 'put', 'patch', 'delete', 'options' - func method ready to extract the docstring """ endpoints_info: list[EndpointInfo] = [] for route in routes: if isinstance(route, Mount | Host): routes = route.routes or [] if isinstance(route, Mount): path = self._remove_converter(route.path) else: path = "" sub_endpoints = [ EndpointInfo( path="".join((path, sub_endpoint.path)), http_method=sub_endpoint.http_method, func=sub_endpoint.func, ) for sub_endpoint in self.get_endpoints(routes) ] endpoints_info.extend(sub_endpoints) elif not isinstance(route, Route) or not 
route.include_in_schema: continue elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): path = self._remove_converter(route.path) for method in route.methods or ["GET"]: if method == "HEAD": continue endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint)) else: path = self._remove_converter(route.path) for method in ["get", "post", "put", "patch", "delete", "options"]: if not hasattr(route.endpoint, method): continue func = getattr(route.endpoint, method) endpoints_info.append(EndpointInfo(path, method.lower(), func)) return endpoints_info def _remove_converter(self, path: str) -> str: """ Remove the converter from the path. For example, a route like this: Route("/users/{id:int}", endpoint=get_user, methods=["GET"]) Should be represented as `/users/{id}` in the OpenAPI schema. """ return _remove_converter_pattern.sub("}", path) def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]: """ Given a function, parse the docstring as YAML and return a dictionary of info. """ docstring = func_or_method.__doc__ if not docstring: return {} assert yaml is not None, "`pyyaml` must be installed to use parse_docstring." # We support having regular docstrings before the schema # definition. Here we return just the schema part from # the docstring. docstring = docstring.split("---")[-1] parsed = yaml.safe_load(docstring) if not isinstance(parsed, dict): # A regular docstring (not yaml formatted) can return # a simple string here, which wouldn't follow the schema. 
return {} return parsed def OpenAPIResponse(self, request: Request) -> Response: routes = request.app.routes schema = self.get_schema(routes=routes) return OpenAPIResponse(schema) class SchemaGenerator(BaseSchemaGenerator): def __init__(self, base_schema: dict[str, Any]) -> None: self.base_schema = base_schema def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]: schema = dict(self.base_schema) schema.setdefault("paths", {}) endpoints_info = self.get_endpoints(routes) for endpoint in endpoints_info: parsed = self.parse_docstring(endpoint.func) if not parsed: continue if endpoint.path not in schema["paths"]: schema["paths"][endpoint.path] = {} schema["paths"][endpoint.path][endpoint.http_method] = parsed return schema ================================================ FILE: starlette/staticfiles.py ================================================ from __future__ import annotations import errno import importlib.util import os import stat from email.utils import parsedate from typing import Union import anyio import anyio.to_thread from starlette._utils import get_route_path from starlette.datastructures import URL, Headers from starlette.exceptions import HTTPException from starlette.responses import FileResponse, RedirectResponse, Response from starlette.types import Receive, Scope, Send PathLike = Union[str, "os.PathLike[str]"] class NotModifiedResponse(Response): NOT_MODIFIED_HEADERS = ( "cache-control", "content-location", "date", "etag", "expires", "vary", ) def __init__(self, headers: Headers): super().__init__( status_code=304, headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS}, ) class StaticFiles: def __init__( self, *, directory: PathLike | None = None, packages: list[str | tuple[str, str]] | None = None, html: bool = False, check_dir: bool = True, follow_symlink: bool = False, ) -> None: self.directory = directory self.packages = packages self.all_directories = self.get_directories(directory, packages) 
self.html = html
        # Set to True after the first request has run `check_config`.
        self.config_checked = False
        self.follow_symlink = follow_symlink
        if check_dir and directory is not None and not os.path.isdir(directory):
            raise RuntimeError(f"Directory '{directory}' does not exist")

    def get_directories(
        self,
        directory: PathLike | None = None,
        packages: list[str | tuple[str, str]] | None = None,
    ) -> list[PathLike]:
        """
        Given `directory` and `packages` arguments, return a list of all the
        directories that should be used for serving static files from.

        A package entry is either a package name (its "statics" directory is
        used) or a `(package, statics_dir)` tuple.
        """
        directories = []
        if directory is not None:
            directories.append(directory)

        for package in packages or []:
            if isinstance(package, tuple):
                package, statics_dir = package
            else:
                statics_dir = "statics"
            # NOTE(review): these checks use `assert`, so they are skipped when
            # Python runs with -O.
            spec = importlib.util.find_spec(package)
            assert spec is not None, f"Package {package!r} could not be found."
            assert spec.origin is not None, f"Package {package!r} could not be found."
            # The statics directory sits next to the package's origin file.
            package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
            assert os.path.isdir(package_directory), (
                f"Directory '{statics_dir!r}' in package {package!r} could not be found."
            )
            directories.append(package_directory)

        return directories

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The ASGI entry point.
        """
        assert scope["type"] == "http"

        if not self.config_checked:
            await self.check_config()
            self.config_checked = True

        path = self.get_path(scope)
        response = await self.get_response(path, scope)
        await response(scope, receive, send)

    def get_path(self, scope: Scope) -> str:
        """
        Given the ASGI scope, return the `path` string to serve up,
        with OS specific path separators, and any '..', '.' components removed.
        """
        # Containment inside the static directories is enforced separately
        # by `lookup_path`.
        route_path = get_route_path(scope)
        return os.path.normpath(os.path.join(*route_path.split("/")))

    async def get_response(self, path: str, scope: Scope) -> Response:
        """
        Returns an HTTP response, given the incoming path, method and request headers.
"""
        if scope["method"] not in ("GET", "HEAD"):
            raise HTTPException(status_code=405)

        # Filesystem lookups are blocking, so they run in a worker thread.
        try:
            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
        except PermissionError:
            raise HTTPException(status_code=401)
        except OSError as exc:
            # Filename is too long, so it can't be a valid static file.
            if exc.errno == errno.ENAMETOOLONG:
                raise HTTPException(status_code=404)

            # Any other OS error propagates unchanged.
            raise exc
        except ValueError:
            # Null bytes or other invalid characters in the path.
            raise HTTPException(status_code=404)

        if stat_result and stat.S_ISREG(stat_result.st_mode):
            # We have a static file to serve.
            return self.file_response(full_path, stat_result, scope)

        elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
            # We're in HTML mode, and have got a directory URL.
            # Check if we have 'index.html' file to serve.
            index_path = os.path.join(path, "index.html")
            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
            if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
                if not scope["path"].endswith("/"):
                    # Directory URLs should redirect to always end in "/".
                    url = URL(scope=scope)
                    url = url.replace(path=url.path + "/")
                    return RedirectResponse(url=url)
                return self.file_response(full_path, stat_result, scope)

        if self.html:
            # Check for '404.html' if we're in HTML mode.
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html") if stat_result and stat.S_ISREG(stat_result.st_mode): return FileResponse(full_path, stat_result=stat_result, status_code=404) raise HTTPException(status_code=404) def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]: for directory in self.all_directories: joined_path = os.path.join(directory, path) if self.follow_symlink: full_path = os.path.abspath(joined_path) directory = os.path.abspath(directory) else: full_path = os.path.realpath(joined_path) directory = os.path.realpath(directory) if os.path.commonpath([full_path, directory]) != str(directory): # Don't allow misbehaving clients to break out of the static files directory. continue try: return full_path, os.stat(full_path) except (FileNotFoundError, NotADirectoryError): continue return "", None def file_response( self, full_path: PathLike, stat_result: os.stat_result, scope: Scope, status_code: int = 200, ) -> Response: request_headers = Headers(scope=scope) response = FileResponse(full_path, status_code=status_code, stat_result=stat_result) if self.is_not_modified(response.headers, request_headers): return NotModifiedResponse(response.headers) return response async def check_config(self) -> None: """ Perform a one-off configuration check that StaticFiles is actually pointed at a directory, so that we can raise loud errors rather than just returning 404 responses. 
""" if self.directory is None: return try: stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) except FileNotFoundError: raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.") if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)): raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.") def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool: """ Given the request and response headers, return `True` if an HTTP "Not Modified" response could be returned instead. """ if if_none_match := request_headers.get("if-none-match"): # The "etag" header is added by FileResponse, so it's always present. etag = response_headers["etag"] return etag in [tag.strip(" W/") for tag in if_none_match.split(",")] try: if_modified_since = parsedate(request_headers["if-modified-since"]) last_modified = parsedate(response_headers["last-modified"]) if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified: return True except KeyError: pass return False ================================================ FILE: starlette/status.py ================================================ """ HTTP codes See HTTP Status Code Registry: https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml And RFC 9110 - https://www.rfc-editor.org/rfc/rfc9110 """ from __future__ import annotations import warnings __all__ = [ "HTTP_100_CONTINUE", "HTTP_101_SWITCHING_PROTOCOLS", "HTTP_102_PROCESSING", "HTTP_103_EARLY_HINTS", "HTTP_200_OK", "HTTP_201_CREATED", "HTTP_202_ACCEPTED", "HTTP_203_NON_AUTHORITATIVE_INFORMATION", "HTTP_204_NO_CONTENT", "HTTP_205_RESET_CONTENT", "HTTP_206_PARTIAL_CONTENT", "HTTP_207_MULTI_STATUS", "HTTP_208_ALREADY_REPORTED", "HTTP_226_IM_USED", "HTTP_300_MULTIPLE_CHOICES", "HTTP_301_MOVED_PERMANENTLY", "HTTP_302_FOUND", "HTTP_303_SEE_OTHER", "HTTP_304_NOT_MODIFIED", "HTTP_305_USE_PROXY", "HTTP_306_RESERVED", 
"HTTP_307_TEMPORARY_REDIRECT", "HTTP_308_PERMANENT_REDIRECT", "HTTP_400_BAD_REQUEST", "HTTP_401_UNAUTHORIZED", "HTTP_402_PAYMENT_REQUIRED", "HTTP_403_FORBIDDEN", "HTTP_404_NOT_FOUND", "HTTP_405_METHOD_NOT_ALLOWED", "HTTP_406_NOT_ACCEPTABLE", "HTTP_407_PROXY_AUTHENTICATION_REQUIRED", "HTTP_408_REQUEST_TIMEOUT", "HTTP_409_CONFLICT", "HTTP_410_GONE", "HTTP_411_LENGTH_REQUIRED", "HTTP_412_PRECONDITION_FAILED", "HTTP_413_CONTENT_TOO_LARGE", "HTTP_414_URI_TOO_LONG", "HTTP_415_UNSUPPORTED_MEDIA_TYPE", "HTTP_416_RANGE_NOT_SATISFIABLE", "HTTP_417_EXPECTATION_FAILED", "HTTP_418_IM_A_TEAPOT", "HTTP_421_MISDIRECTED_REQUEST", "HTTP_422_UNPROCESSABLE_CONTENT", "HTTP_423_LOCKED", "HTTP_424_FAILED_DEPENDENCY", "HTTP_425_TOO_EARLY", "HTTP_426_UPGRADE_REQUIRED", "HTTP_428_PRECONDITION_REQUIRED", "HTTP_429_TOO_MANY_REQUESTS", "HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE", "HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS", "HTTP_500_INTERNAL_SERVER_ERROR", "HTTP_501_NOT_IMPLEMENTED", "HTTP_502_BAD_GATEWAY", "HTTP_503_SERVICE_UNAVAILABLE", "HTTP_504_GATEWAY_TIMEOUT", "HTTP_505_HTTP_VERSION_NOT_SUPPORTED", "HTTP_506_VARIANT_ALSO_NEGOTIATES", "HTTP_507_INSUFFICIENT_STORAGE", "HTTP_508_LOOP_DETECTED", "HTTP_510_NOT_EXTENDED", "HTTP_511_NETWORK_AUTHENTICATION_REQUIRED", "WS_1000_NORMAL_CLOSURE", "WS_1001_GOING_AWAY", "WS_1002_PROTOCOL_ERROR", "WS_1003_UNSUPPORTED_DATA", "WS_1005_NO_STATUS_RCVD", "WS_1006_ABNORMAL_CLOSURE", "WS_1007_INVALID_FRAME_PAYLOAD_DATA", "WS_1008_POLICY_VIOLATION", "WS_1009_MESSAGE_TOO_BIG", "WS_1010_MANDATORY_EXT", "WS_1011_INTERNAL_ERROR", "WS_1012_SERVICE_RESTART", "WS_1013_TRY_AGAIN_LATER", "WS_1014_BAD_GATEWAY", "WS_1015_TLS_HANDSHAKE", ] HTTP_100_CONTINUE = 100 HTTP_101_SWITCHING_PROTOCOLS = 101 HTTP_102_PROCESSING = 102 HTTP_103_EARLY_HINTS = 103 HTTP_200_OK = 200 HTTP_201_CREATED = 201 HTTP_202_ACCEPTED = 202 HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203 HTTP_204_NO_CONTENT = 204 HTTP_205_RESET_CONTENT = 205 HTTP_206_PARTIAL_CONTENT = 206 HTTP_207_MULTI_STATUS = 207 
HTTP_208_ALREADY_REPORTED = 208 HTTP_226_IM_USED = 226 HTTP_300_MULTIPLE_CHOICES = 300 HTTP_301_MOVED_PERMANENTLY = 301 HTTP_302_FOUND = 302 HTTP_303_SEE_OTHER = 303 HTTP_304_NOT_MODIFIED = 304 HTTP_305_USE_PROXY = 305 HTTP_306_RESERVED = 306 HTTP_307_TEMPORARY_REDIRECT = 307 HTTP_308_PERMANENT_REDIRECT = 308 HTTP_400_BAD_REQUEST = 400 HTTP_401_UNAUTHORIZED = 401 HTTP_402_PAYMENT_REQUIRED = 402 HTTP_403_FORBIDDEN = 403 HTTP_404_NOT_FOUND = 404 HTTP_405_METHOD_NOT_ALLOWED = 405 HTTP_406_NOT_ACCEPTABLE = 406 HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407 HTTP_408_REQUEST_TIMEOUT = 408 HTTP_409_CONFLICT = 409 HTTP_410_GONE = 410 HTTP_411_LENGTH_REQUIRED = 411 HTTP_412_PRECONDITION_FAILED = 412 HTTP_413_CONTENT_TOO_LARGE = 413 HTTP_414_URI_TOO_LONG = 414 HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415 HTTP_416_RANGE_NOT_SATISFIABLE = 416 HTTP_417_EXPECTATION_FAILED = 417 HTTP_418_IM_A_TEAPOT = 418 HTTP_421_MISDIRECTED_REQUEST = 421 HTTP_422_UNPROCESSABLE_CONTENT = 422 HTTP_423_LOCKED = 423 HTTP_424_FAILED_DEPENDENCY = 424 HTTP_425_TOO_EARLY = 425 HTTP_426_UPGRADE_REQUIRED = 426 HTTP_428_PRECONDITION_REQUIRED = 428 HTTP_429_TOO_MANY_REQUESTS = 429 HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431 HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451 HTTP_500_INTERNAL_SERVER_ERROR = 500 HTTP_501_NOT_IMPLEMENTED = 501 HTTP_502_BAD_GATEWAY = 502 HTTP_503_SERVICE_UNAVAILABLE = 503 HTTP_504_GATEWAY_TIMEOUT = 504 HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505 HTTP_506_VARIANT_ALSO_NEGOTIATES = 506 HTTP_507_INSUFFICIENT_STORAGE = 507 HTTP_508_LOOP_DETECTED = 508 HTTP_510_NOT_EXTENDED = 510 HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 """ WebSocket codes https://www.iana.org/assignments/websocket/websocket.xml#close-code-number https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent """ WS_1000_NORMAL_CLOSURE = 1000 WS_1001_GOING_AWAY = 1001 WS_1002_PROTOCOL_ERROR = 1002 WS_1003_UNSUPPORTED_DATA = 1003 WS_1005_NO_STATUS_RCVD = 1005 WS_1006_ABNORMAL_CLOSURE = 1006 WS_1007_INVALID_FRAME_PAYLOAD_DATA = 
1007 WS_1008_POLICY_VIOLATION = 1008 WS_1009_MESSAGE_TOO_BIG = 1009 WS_1010_MANDATORY_EXT = 1010 WS_1011_INTERNAL_ERROR = 1011 WS_1012_SERVICE_RESTART = 1012 WS_1013_TRY_AGAIN_LATER = 1013 WS_1014_BAD_GATEWAY = 1014 WS_1015_TLS_HANDSHAKE = 1015 __deprecated__ = { "HTTP_413_REQUEST_ENTITY_TOO_LARGE": 413, "HTTP_414_REQUEST_URI_TOO_LONG": 414, "HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": 416, "HTTP_422_UNPROCESSABLE_ENTITY": 422, } def __getattr__(name: str) -> int: deprecation_changes = { "HTTP_413_REQUEST_ENTITY_TOO_LARGE": "HTTP_413_CONTENT_TOO_LARGE", "HTTP_414_REQUEST_URI_TOO_LONG": "HTTP_414_URI_TOO_LONG", "HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": "HTTP_416_RANGE_NOT_SATISFIABLE", "HTTP_422_UNPROCESSABLE_ENTITY": "HTTP_422_UNPROCESSABLE_CONTENT", } deprecated = __deprecated__.get(name) if deprecated: warnings.warn( f"'{name}' is deprecated. Use '{deprecation_changes[name]}' instead.", category=DeprecationWarning, stacklevel=3, ) return deprecated raise AttributeError(f"module 'starlette.status' has no attribute '{name}'") def __dir__() -> list[str]: return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover ================================================ FILE: starlette/templating.py ================================================ from __future__ import annotations from collections.abc import Callable, Mapping, Sequence from os import PathLike from typing import TYPE_CHECKING, Any, overload from starlette.background import BackgroundTask from starlette.datastructures import URL from starlette.requests import Request from starlette.responses import HTMLResponse from starlette.types import Receive, Scope, Send try: import jinja2 # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1 # hence we try to get pass_context (most installs will be >=3.1) # and fall back to contextfunction, # adding a type ignore for mypy to let us access an attribute that may not exist if TYPE_CHECKING: pass_context = 
jinja2.pass_context else: if hasattr(jinja2, "pass_context"): pass_context = jinja2.pass_context else: # pragma: no cover pass_context = jinja2.contextfunction # type: ignore[attr-defined] except ImportError as _import_error: # pragma: no cover raise ImportError("jinja2 must be installed to use Jinja2Templates") from _import_error class _TemplateResponse(HTMLResponse): def __init__( self, template: Any, context: dict[str, Any], status_code: int = 200, headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ): self.template = template self.context = context content = template.render(context) super().__init__(content, status_code, headers, media_type, background) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = self.context.get("request", {}) extensions = request.get("extensions", {}) if "http.response.debug" in extensions: # pragma: no branch await send({"type": "http.response.debug", "info": {"template": self.template, "context": self.context}}) await super().__call__(scope, receive, send) class Jinja2Templates: """Jinja2 template renderer. Example: ```python from starlette.templating import Jinja2Templates templates = Jinja2Templates(directory="templates") async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") ``` """ @overload def __init__( self, directory: str | PathLike[str] | Sequence[str | PathLike[str]], *, context_processors: list[Callable[[Request], dict[str, Any]]] | None = None, ) -> None: ... @overload def __init__( self, *, env: jinja2.Environment, context_processors: list[Callable[[Request], dict[str, Any]]] | None = None, ) -> None: ... 
def __init__( self, directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None, *, context_processors: list[Callable[[Request], dict[str, Any]]] | None = None, env: jinja2.Environment | None = None, ) -> None: assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed" self.context_processors = context_processors or [] if directory is not None: loader = jinja2.FileSystemLoader(directory) self.env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape()) elif env is not None: # pragma: no branch self.env = env self._setup_env_defaults(self.env) def _setup_env_defaults(self, env: jinja2.Environment) -> None: @pass_context def url_for( context: dict[str, Any], name: str, /, **path_params: Any, ) -> URL: request: Request = context["request"] return request.url_for(name, **path_params) env.globals.setdefault("url_for", url_for) def get_template(self, name: str) -> jinja2.Template: return self.env.get_template(name) def TemplateResponse( self, request: Request, name: str, context: dict[str, Any] | None = None, status_code: int = 200, headers: Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, ) -> _TemplateResponse: """ Render a template and return an HTML response. Args: request: The incoming request instance. name: The template file name to render. context: Variables to pass to the template. status_code: HTTP status code for the response. headers: Additional headers to include in the response. media_type: Media type for the response. background: Background task to run after response is sent. Returns: An HTML response with the rendered template content. 
""" context = context or {} context.setdefault("request", request) for context_processor in self.context_processors: context.update(context_processor(request)) template = self.get_template(name) return _TemplateResponse( template, context, status_code=status_code, headers=headers, media_type=media_type, background=background, ) ================================================ FILE: starlette/testclient.py ================================================ from __future__ import annotations import contextlib import inspect import io import json import math import sys import warnings from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence from concurrent.futures import Future from contextlib import AbstractContextManager from types import GeneratorType from typing import ( Any, Literal, TypedDict, TypeGuard, cast, ) from urllib.parse import unquote, urljoin import anyio import anyio.abc import anyio.from_thread from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect if sys.version_info >= (3, 11): # pragma: no cover from typing import Self else: # pragma: no cover from typing_extensions import Self try: import httpx except ModuleNotFoundError: # pragma: no cover raise RuntimeError( "The starlette.testclient module requires the httpx package to be installed.\n" "You can install this with:\n" " $ pip install httpx\n" ) _PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]] ASGIInstance = Callable[[Receive, Send], Awaitable[None]] ASGI2App = Callable[[Scope], ASGIInstance] ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]] _RequestData = Mapping[str, str | Iterable[str] | bytes] def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: if inspect.isclass(app): return hasattr(app, "__await__") return 
is_async_callable(app) class _WrapASGI2: """ Provide an ASGI3 interface onto an ASGI2 app. """ def __init__(self, app: ASGI2App) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: instance = self.app(scope) await instance(receive, send) class _AsyncBackend(TypedDict): backend: str backend_options: dict[str, Any] class _Upgrade(Exception): def __init__(self, session: WebSocketTestSession) -> None: self.session = session class WebSocketDenialResponse( # type: ignore[misc] httpx.Response, WebSocketDisconnect, ): """ A special case of `WebSocketDisconnect`, raised in the `TestClient` if the `WebSocket` is closed before being accepted with a `send_denial_response()`. """ class WebSocketTestSession: def __init__( self, app: ASGI3App, scope: Scope, portal_factory: _PortalFactoryType, ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None self.portal_factory = portal_factory self.extra_headers = None def __enter__(self) -> WebSocketTestSession: with contextlib.ExitStack() as stack: self.portal = portal = stack.enter_context(self.portal_factory()) fut, cs = portal.start_task(self._run) stack.callback(fut.result) stack.callback(portal.call, cs.cancel) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) self.accepted_subprotocol = message.get("subprotocol", None) self.extra_headers = message.get("headers", None) stack.callback(self.close, 1000) self.exit_stack = stack.pop_all() return self def __exit__(self, *args: Any) -> bool | None: return self.exit_stack.__exit__(*args) async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: """ The sub-thread in which the websocket session runs. 
""" send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) send_tx, send_rx = send receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) receive_tx, receive_rx = receive with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: self._receive_tx = receive_tx self._send_rx = send_rx task_status.started(cs) await self.app(self.scope, receive_rx.receive, send_tx.send) # wait for cs.cancel to be called before closing streams await anyio.sleep_forever() def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) elif message["type"] == "websocket.http.response.start": status_code: int = message["status"] headers: list[tuple[bytes, bytes]] = message["headers"] body: list[bytes] = [] while True: message = self.receive() assert message["type"] == "websocket.http.response.body" body.append(message["body"]) if not message.get("more_body", False): break raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) def send(self, message: Message) -> None: self.portal.call(self._receive_tx.send, message) def send_text(self, data: str) -> None: self.send({"type": "websocket.receive", "text": data}) def send_bytes(self, data: bytes) -> None: self.send({"type": "websocket.receive", "bytes": data}) def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None: text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) if mode == "text": self.send({"type": "websocket.receive", "text": text}) else: self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) def close(self, code: int = 1000, reason: str | None = None) -> None: self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) def receive(self) -> Message: return self.portal.call(self._send_rx.receive) def 
receive_text(self) -> str: message = self.receive() self._raise_on_close(message) return cast(str, message["text"]) def receive_bytes(self) -> bytes: message = self.receive() self._raise_on_close(message) return cast(bytes, message["bytes"]) def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any: message = self.receive() self._raise_on_close(message) if mode == "text": text = message["text"] else: text = message["bytes"].decode("utf-8") return json.loads(text) class _TestClientTransport(httpx.BaseTransport): def __init__( self, app: ASGI3App, portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", *, client: tuple[str, int], app_state: dict[str, Any], ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path self.portal_factory = portal_factory self.app_state = app_state self.client = client def handle_request(self, request: httpx.Request) -> httpx.Response: scheme = request.url.scheme netloc = request.url.netloc.decode(encoding="ascii") path = request.url.path raw_path = request.url.raw_path query = request.url.query.decode(encoding="ascii") default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] if ":" in netloc: host, port_string = netloc.split(":", 1) port = int(port_string) else: host = netloc port = default_port # Include the 'host' header. if "host" in request.headers: headers: list[tuple[bytes, bytes]] = [] elif port == default_port: # pragma: no cover headers = [(b"host", host.encode())] else: # pragma: no cover headers = [(b"host", (f"{host}:{port}").encode())] # Include other request headers. 
headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] scope: dict[str, Any] if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) if subprotocol is None: subprotocols: Sequence[str] = [] else: subprotocols = [value.strip() for value in subprotocol.split(",")] scope = { "type": "websocket", "path": unquote(path), "raw_path": raw_path.split(b"?", 1)[0], "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), "headers": headers, "client": self.client, "server": [host, port], "subprotocols": subprotocols, "state": self.app_state.copy(), "extensions": {"websocket.http.response": {}}, } session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) scope = { "type": "http", "http_version": "1.1", "method": request.method, "path": unquote(path), "raw_path": raw_path.split(b"?", 1)[0], "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), "headers": headers, "client": self.client, "server": [host, port], "extensions": {"http.response.debug": {}}, "state": self.app_state.copy(), } request_complete = False response_started = False response_complete: anyio.Event raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()} template = None context = None async def receive() -> Message: nonlocal request_complete if request_complete: if not response_complete.is_set(): await response_complete.wait() return {"type": "http.disconnect"} body = request.read() if isinstance(body, str): body_bytes: bytes = body.encode("utf-8") # pragma: no cover elif body is None: body_bytes = b"" # pragma: no cover elif isinstance(body, GeneratorType): try: # pragma: no cover chunk = body.send(None) if isinstance(chunk, str): chunk = chunk.encode("utf-8") return {"type": "http.request", "body": chunk, "more_body": True} except StopIteration: # pragma: no cover request_complete = True return {"type": "http.request", "body": b""} else: body_bytes = 
body request_complete = True return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert not response_started, 'Received multiple "http.response.start" messages.' raw_kwargs["status_code"] = message["status"] raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])] response_started = True elif message["type"] == "http.response.body": assert response_started, 'Received "http.response.body" without "http.response.start".' assert not response_complete.is_set(), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": raw_kwargs["stream"].write(body) if not more_body: raw_kwargs["stream"].seek(0) response_complete.set() elif message["type"] == "http.response.debug": template = message["info"]["template"] context = message["info"]["context"] try: with self.portal_factory() as portal: response_complete = portal.call(anyio.Event) portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: raise exc if self.raise_server_exceptions: assert response_started, "TestClient did not receive any response." 
elif not response_started: raw_kwargs = { "status_code": 500, "headers": [], "stream": io.BytesIO(), } raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) response = httpx.Response(**raw_kwargs, request=request) if template is not None: response.template = template # type: ignore[attr-defined] response.context = context # type: ignore[attr-defined] return response class TestClient(httpx.Client): __test__ = False task: Future[None] portal: anyio.abc.BlockingPortal | None = None def __init__( self, app: ASGIApp, base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", backend: Literal["asyncio", "trio"] = "asyncio", backend_options: dict[str, Any] | None = None, cookies: httpx._types.CookieTypes | None = None, headers: dict[str, str] | None = None, follow_redirects: bool = True, client: tuple[str, int] = ("testclient", 50000), ) -> None: self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {}) if _is_asgi3(app): asgi_app = app else: app = cast(ASGI2App, app) # type: ignore[assignment] asgi_app = _WrapASGI2(app) # type: ignore[arg-type] self.app = asgi_app self.app_state: dict[str, Any] = {} transport = _TestClientTransport( self.app, portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, app_state=self.app_state, client=client, ) if headers is None: headers = {} headers.setdefault("user-agent", "testclient") super().__init__( base_url=base_url, headers=headers, transport=transport, follow_redirects=follow_redirects, cookies=cookies, ) @contextlib.contextmanager def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]: if self.portal is not None: yield self.portal else: with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal: yield portal def request( # type: ignore[override] self, method: str, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: 
_RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, Any] | None = None, ) -> httpx.Response: if timeout is not httpx.USE_CLIENT_DEFAULT: warnings.warn( "You should not use the 'timeout' argument with the TestClient. " "See https://github.com/Kludex/starlette/issues/1108 for more information.", DeprecationWarning, ) url = self._merge_url(url) return super().request( method, url, content=content, data=data, files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def get( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().get( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def options( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, 
cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().options( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def head( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().head( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def post( # type: ignore[override] self, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: 
dict[str, Any] | None = None, ) -> httpx.Response: return super().post( url, content=content, data=data, files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def put( # type: ignore[override] self, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().put( url, content=content, data=data, files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def patch( # type: ignore[override] self, url: httpx._types.URLTypes, *, content: httpx._types.RequestContent | None = None, data: _RequestData | None = None, files: httpx._types.RequestFiles | None = None, json: Any = None, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().patch( url, content=content, data=data, 
files=files, json=json, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def delete( # type: ignore[override] self, url: httpx._types.URLTypes, *, params: httpx._types.QueryParamTypes | None = None, headers: httpx._types.HeaderTypes | None = None, cookies: httpx._types.CookieTypes | None = None, auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, Any] | None = None, ) -> httpx.Response: return super().delete( url, params=params, headers=headers, cookies=cookies, auth=auth, follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) def websocket_connect( self, url: str, subprotocols: Sequence[str] | None = None, **kwargs: Any, ) -> WebSocketTestSession: url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) headers.setdefault("connection", "upgrade") headers.setdefault("sec-websocket-key", "testserver==") headers.setdefault("sec-websocket-version", "13") if subprotocols is not None: headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) kwargs["headers"] = headers try: super().request("GET", url, **kwargs) except _Upgrade as exc: session = exc.session else: raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover return session def __enter__(self) -> Self: with contextlib.ExitStack() as stack: self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend)) @stack.callback def reset_portal() -> None: self.portal = None send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = ( anyio.create_memory_object_stream(math.inf) ) receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = 
anyio.create_memory_object_stream( math.inf ) for channel in (*send, *receive): stack.callback(channel.close) self.stream_send = StapledObjectStream(*send) self.stream_receive = StapledObjectStream(*receive) self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) @stack.callback def wait_shutdown() -> None: portal.call(self.wait_shutdown) self.exit_stack = stack.pop_all() return self def __exit__(self, *args: Any) -> None: self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan", "state": self.app_state} try: await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: await self.stream_send.send(None) async def wait_startup(self) -> None: await self.stream_receive.send({"type": "lifespan.startup"}) async def receive() -> Any: message = await self.stream_send.receive() if message is None: self.task.result() return message message = await receive() assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": await receive() async def wait_shutdown(self) -> None: async def receive() -> Any: message = await self.stream_send.receive() if message is None: self.task.result() return message await self.stream_receive.send({"type": "lifespan.shutdown"}) message = await receive() assert message["type"] in ( "lifespan.shutdown.complete", "lifespan.shutdown.failed", ) if message["type"] == "lifespan.shutdown.failed": await receive() ================================================ FILE: starlette/types.py ================================================ from collections.abc import Awaitable, Callable, Mapping, MutableMapping from contextlib import AbstractAsyncContextManager from typing import TYPE_CHECKING, Any, TypeVar if TYPE_CHECKING: from starlette.requests import Request from starlette.responses import Response from starlette.websockets import WebSocket AppType = TypeVar("AppType") Scope = MutableMapping[str, Any] Message = 
MutableMapping[str, Any] Receive = Callable[[], Awaitable[Message]] Send = Callable[[Message], Awaitable[None]] ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]] StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]] Lifespan = StatelessLifespan[AppType] | StatefulLifespan[AppType] HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"] WebSocketExceptionHandler = Callable[["WebSocket", Exception], Awaitable[None]] ExceptionHandler = HTTPExceptionHandler | WebSocketExceptionHandler ================================================ FILE: starlette/websockets.py ================================================ from __future__ import annotations import enum import json from collections.abc import AsyncIterator, Iterable from typing import Any, cast from starlette.requests import HTTPConnection, StateT from starlette.responses import Response from starlette.types import Message, Receive, Scope, Send class WebSocketState(enum.Enum): CONNECTING = 0 CONNECTED = 1 DISCONNECTED = 2 RESPONSE = 3 class WebSocketDisconnect(Exception): def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.code = code self.reason = reason or "" class WebSocket(HTTPConnection[StateT]): def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: super().__init__(scope) assert scope["type"] == "websocket" self._receive = receive self._send = send self.client_state = WebSocketState.CONNECTING self.application_state = WebSocketState.CONNECTING async def receive(self) -> Message: """ Receive ASGI websocket messages, ensuring valid state transitions. 
""" if self.client_state == WebSocketState.CONNECTING: message = await self._receive() message_type = message["type"] if message_type != "websocket.connect": raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}') self.client_state = WebSocketState.CONNECTED return message elif self.client_state == WebSocketState.CONNECTED: message = await self._receive() message_type = message["type"] if message_type not in {"websocket.receive", "websocket.disconnect"}: raise RuntimeError( f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}' ) if message_type == "websocket.disconnect": self.client_state = WebSocketState.DISCONNECTED return message else: raise RuntimeError('Cannot call "receive" once a disconnect message has been received.') async def send(self, message: Message) -> None: """ Send ASGI websocket messages, ensuring valid state transitions. """ if self.application_state == WebSocketState.CONNECTING: message_type = message["type"] if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}: raise RuntimeError( 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", ' f"but got {message_type!r}" ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED elif message_type == "websocket.http.response.start": self.application_state = WebSocketState.RESPONSE else: self.application_state = WebSocketState.CONNECTED await self._send(message) elif self.application_state == WebSocketState.CONNECTED: message_type = message["type"] if message_type not in {"websocket.send", "websocket.close"}: raise RuntimeError( f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}' ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED try: await self._send(message) except OSError: self.application_state = WebSocketState.DISCONNECTED 
raise WebSocketDisconnect(code=1006) elif self.application_state == WebSocketState.RESPONSE: message_type = message["type"] if message_type != "websocket.http.response.body": raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}') if not message.get("more_body", False): self.application_state = WebSocketState.DISCONNECTED await self._send(message) else: raise RuntimeError('Cannot call "send" once a close message has been sent.') async def accept( self, subprotocol: str | None = None, headers: Iterable[tuple[bytes, bytes]] | None = None, ) -> None: headers = headers or [] if self.client_state == WebSocketState.CONNECTING: # pragma: no branch # If we haven't yet seen the 'connect' message, then wait for it first. await self.receive() await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}) def _raise_on_disconnect(self, message: Message) -> None: if message["type"] == "websocket.disconnect": raise WebSocketDisconnect(message["code"], message.get("reason")) async def receive_text(self) -> str: if self.application_state != WebSocketState.CONNECTED: raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) return cast(str, message["text"]) async def receive_bytes(self) -> bytes: if self.application_state != WebSocketState.CONNECTED: raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) return cast(bytes, message["bytes"]) async def receive_json(self, mode: str = "text") -> Any: if mode not in {"text", "binary"}: raise RuntimeError('The "mode" argument should be "text" or "binary".') if self.application_state != WebSocketState.CONNECTED: raise RuntimeError('WebSocket is not connected. 
Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) if mode == "text": text = message["text"] else: text = message["bytes"].decode("utf-8") return json.loads(text) async def iter_text(self) -> AsyncIterator[str]: try: while True: yield await self.receive_text() except WebSocketDisconnect: pass async def iter_bytes(self) -> AsyncIterator[bytes]: try: while True: yield await self.receive_bytes() except WebSocketDisconnect: pass async def iter_json(self) -> AsyncIterator[Any]: try: while True: yield await self.receive_json() except WebSocketDisconnect: pass async def send_text(self, data: str) -> None: await self.send({"type": "websocket.send", "text": data}) async def send_bytes(self, data: bytes) -> None: await self.send({"type": "websocket.send", "bytes": data}) async def send_json(self, data: Any, mode: str = "text") -> None: if mode not in {"text", "binary"}: raise RuntimeError('The "mode" argument should be "text" or "binary".') text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) if mode == "text": await self.send({"type": "websocket.send", "text": text}) else: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) async def close(self, code: int = 1000, reason: str | None = None) -> None: await self.send({"type": "websocket.close", "code": code, "reason": reason or ""}) async def send_denial_response(self, response: Response) -> None: if "websocket.http.response" in self.scope.get("extensions", {}): await response(self.scope, self.receive, self.send) else: raise RuntimeError("The server doesn't support the Websocket Denial Response extension.") class WebSocketClose: def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.code = code self.reason = reason or "" async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.close", "code": self.code, "reason": self.reason}) 
================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/conftest.py ================================================ from __future__ import annotations import functools from typing import Any, Literal import pytest from starlette.testclient import TestClient from tests.types import TestClientFactory @pytest.fixture def test_client_factory( anyio_backend_name: Literal["asyncio", "trio"], anyio_backend_options: dict[str, Any], ) -> TestClientFactory: # anyio_backend_name defined by: # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on return functools.partial( TestClient, backend=anyio_backend_name, backend_options=anyio_backend_options, ) ================================================ FILE: tests/middleware/__init__.py ================================================ ================================================ FILE: tests/middleware/test_base.py ================================================ from __future__ import annotations import contextvars from collections.abc import AsyncGenerator, AsyncIterator, Generator from contextlib import AsyncExitStack from pathlib import Path from typing import Any import anyio import pytest from anyio.abc import TaskStatus from starlette.applications import Starlette from starlette.background import BackgroundTask from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import ClientDisconnect, Request from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket from tests.types import TestClientFactory class 
CustomMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: response = await call_next(request) response.headers["Custom-Header"] = "Example" return response def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage") def exc(request: Request) -> None: raise Exception("Exc") def exc_stream(request: Request) -> StreamingResponse: return StreamingResponse(_generate_faulty_stream()) def _generate_faulty_stream() -> Generator[bytes, None, None]: yield b"Ok" raise Exception("Faulty Stream") class NoResponse: def __init__( self, scope: Scope, receive: Receive, send: Send, ): pass def __await__(self) -> Generator[Any, None, None]: return self.dispatch().__await__() async def dispatch(self) -> None: pass async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() app = Starlette( routes=[ Route("/", endpoint=homepage), Route("/exc", endpoint=exc), Route("/exc-stream", endpoint=exc_stream), Route("/no-response", endpoint=NoResponse), WebSocketRoute("/ws", endpoint=websocket_endpoint), ], middleware=[Middleware(CustomMiddleware)], ) def test_custom_middleware(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" with pytest.raises(Exception) as ctx: response = client.get("/exc") assert str(ctx.value) == "Exc" with pytest.raises(Exception) as ctx: response = client.get("/exc-stream") assert str(ctx.value) == "Faulty Stream" with pytest.raises(RuntimeError): response = client.get("/no-response") with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" 
def test_state_data_across_multiple_middlewares( test_client_factory: TestClientFactory, ) -> None: expected_value1 = "foo" expected_value2 = "bar" class aMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: request.state.foo = expected_value1 response = await call_next(request) return response class bMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: request.state.bar = expected_value2 response = await call_next(request) response.headers["X-State-Foo"] = request.state.foo return response class cMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: response = await call_next(request) response.headers["X-State-Bar"] = request.state.bar return response def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK") app = Starlette( routes=[Route("/", homepage)], middleware=[ Middleware(aMiddleware), Middleware(bMiddleware), Middleware(cMiddleware), ], ) client = test_client_factory(app) response = client.get("/") assert response.text == "OK" assert response.headers["X-State-Foo"] == expected_value1 assert response.headers["X-State-Bar"] == expected_value2 def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage") app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]) client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None: # Test for https://github.com/Kludex/starlette/issues/1022 class CustomMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> PlainTextResponse: await 
call_next(request) return PlainTextResponse("Custom") app = Starlette(middleware=[Middleware(CustomMiddleware)]) client = test_client_factory(app) response = client.get("/does_not_exist") assert response.text == "Custom" ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") class CustomMiddlewareWithoutBaseHTTPMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ctxvar.set("set by middleware") await self.app(scope, receive, send) assert ctxvar.get() == "set by endpoint" class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: ctxvar.set("set by middleware") resp = await call_next(request) assert ctxvar.get() == "set by endpoint" return resp # pragma: no cover @pytest.mark.parametrize( "middleware_cls", [ CustomMiddlewareWithoutBaseHTTPMiddleware, pytest.param( CustomMiddlewareUsingBaseHTTPMiddleware, marks=pytest.mark.xfail( reason=( "BaseHTTPMiddleware creates a TaskGroup which copies the context" "and erases any changes to it made within the TaskGroup" ), raises=AssertionError, ), ), ], ) def test_contextvars( test_client_factory: TestClientFactory, middleware_cls: _MiddlewareFactory[Any], ) -> None: # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) async def homepage(request: Request) -> PlainTextResponse: assert ctxvar.get() == "set by middleware" ctxvar.set("set by endpoint") return PlainTextResponse("Homepage") app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]) client = test_client_factory(app) response = client.get("/") assert response.status_code == 200, response.content @pytest.mark.anyio async def 
test_run_background_tasks_even_if_client_disconnects() -> None: # test for https://github.com/Kludex/starlette/issues/1438 response_complete = anyio.Event() background_task_run = anyio.Event() async def sleep_and_set() -> None: # small delay to give BaseHTTPMiddleware a chance to cancel us # this is required to make the test fail prior to fixing the issue # so do not be surprised if you remove it and the test still passes await anyio.sleep(0.1) background_task_run.set() async def endpoint_with_background_task(_: Request) -> PlainTextResponse: return PlainTextResponse(background=BackgroundTask(sleep_and_set)) async def passthrough( request: Request, call_next: RequestResponseEndpoint, ) -> Response: return await call_next(request) app = Starlette( middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], routes=[Route("/", endpoint_with_background_task)], ) scope = { "type": "http", "version": "3", "method": "GET", "path": "/", } async def receive() -> Message: raise NotImplementedError("Should not be called!") async def send(message: Message) -> None: if message["type"] == "http.response.body": if not message.get("more_body", False): # pragma: no branch response_complete.set() await app(scope, receive, send) assert background_task_run.is_set() def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None: # test for https://github.com/Kludex/starlette/issues/2625 async def sleep_and_set() -> None: await anyio.sleep(0.1) raise ValueError("TEST") async def endpoint_with_background_task(_: Request) -> PlainTextResponse: return PlainTextResponse(background=BackgroundTask(sleep_and_set)) async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(request) app = Starlette( middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], routes=[Route("/", endpoint_with_background_task)], ) client = test_client_factory(app) with pytest.raises(ValueError, match="TEST"): 
client.get("/") def test_exception_can_be_caught(test_client_factory: TestClientFactory) -> None: async def error_endpoint(_: Request) -> None: raise ValueError("TEST") async def catches_error(request: Request, call_next: RequestResponseEndpoint) -> Response: try: return await call_next(request) except ValueError as exc: return PlainTextResponse(content=str(exc), status_code=400) app = Starlette( middleware=[Middleware(BaseHTTPMiddleware, dispatch=catches_error)], routes=[Route("/", error_endpoint)], ) client = test_client_factory(app) response = client.get("/") assert response.status_code == 400 assert response.text == "TEST" @pytest.mark.anyio async def test_do_not_block_on_background_tasks() -> None: response_complete = anyio.Event() events: list[str | Message] = [] async def sleep_and_set() -> None: events.append("Background task started") await anyio.sleep(0.1) events.append("Background task finished") async def endpoint_with_background_task(_: Request) -> PlainTextResponse: return PlainTextResponse(content="Hello", background=BackgroundTask(sleep_and_set)) async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(request) app = Starlette( middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], routes=[Route("/", endpoint_with_background_task)], ) scope = { "type": "http", "version": "3", "method": "GET", "path": "/", } async def receive() -> Message: raise NotImplementedError("Should not be called!") async def send(message: Message) -> None: if message["type"] == "http.response.body": events.append(message) if not message.get("more_body", False): response_complete.set() async with anyio.create_task_group() as tg: tg.start_soon(app, scope, receive, send) tg.start_soon(app, scope, receive, send) # Without the fix, the background tasks would start and finish before the # last http.response.body is sent. 
assert events == [ {"body": b"Hello", "more_body": True, "type": "http.response.body"}, {"body": b"", "more_body": False, "type": "http.response.body"}, {"body": b"Hello", "more_body": True, "type": "http.response.body"}, {"body": b"", "more_body": False, "type": "http.response.body"}, "Background task started", "Background task started", "Background task finished", "Background task finished", ] @pytest.mark.anyio async def test_run_context_manager_exit_even_if_client_disconnects() -> None: # test for https://github.com/Kludex/starlette/issues/1678#issuecomment-1172916042 response_complete = anyio.Event() context_manager_exited = anyio.Event() async def sleep_and_set() -> None: # small delay to give BaseHTTPMiddleware a chance to cancel us # this is required to make the test fail prior to fixing the issue # so do not be surprised if you remove it and the test still passes await anyio.sleep(0.1) context_manager_exited.set() class ContextManagerMiddleware: def __init__(self, app: ASGIApp): self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with AsyncExitStack() as stack: stack.push_async_callback(sleep_and_set) await self.app(scope, receive, send) async def simple_endpoint(_: Request) -> PlainTextResponse: return PlainTextResponse(background=BackgroundTask(sleep_and_set)) async def passthrough( request: Request, call_next: RequestResponseEndpoint, ) -> Response: return await call_next(request) app = Starlette( middleware=[ Middleware(BaseHTTPMiddleware, dispatch=passthrough), Middleware(ContextManagerMiddleware), ], routes=[Route("/", simple_endpoint)], ) scope = { "type": "http", "version": "3", "method": "GET", "path": "/", } async def receive() -> Message: raise NotImplementedError("Should not be called!") async def send(message: Message) -> None: if message["type"] == "http.response.body": if not message.get("more_body", False): # pragma: no branch response_complete.set() await app(scope, receive, send) assert 
context_manager_exited.is_set() def test_app_receives_http_disconnect_while_sending_if_discarded( test_client_factory: TestClientFactory, ) -> None: class DiscardingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: Any, ) -> PlainTextResponse: # As a matter of ordering, this test targets the case where the downstream # app response is discarded while it is sending a response body. # We need to wait for the downstream app to begin sending a response body # before sending the middleware response that will overwrite the downstream # response. downstream_app_response = await call_next(request) body_generator = downstream_app_response.body_iterator try: await body_generator.__anext__() finally: await body_generator.aclose() return PlainTextResponse("Custom") async def downstream_app( scope: Scope, receive: Receive, send: Send, ) -> None: await send( { "type": "http.response.start", "status": 200, "headers": [ (b"content-type", b"text/plain"), ], } ) async with anyio.create_task_group() as task_group: async def cancel_on_disconnect( *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: task_status.started() while True: message = await receive() if message["type"] == "http.disconnect": # pragma: no branch task_group.cancel_scope.cancel() break # Using start instead of start_soon to ensure that # cancel_on_disconnect is scheduled by the event loop # before we start returning the body await task_group.start(cancel_on_disconnect) # A timeout is set for 0.1 second in order to ensure that # we never deadlock the test run in an infinite loop with anyio.move_on_after(0.1): while True: await send( { "type": "http.response.body", "body": b"chunk ", "more_body": True, } ) pytest.fail("http.disconnect should have been received and canceled the scope") # pragma: no cover app = DiscardingMiddleware(downstream_app) client = test_client_factory(app) response = client.get("/does_not_exist") assert response.text == "Custom" def 
test_app_receives_http_disconnect_after_sending_if_discarded( test_client_factory: TestClientFactory, ) -> None: class DiscardingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> PlainTextResponse: await call_next(request) return PlainTextResponse("Custom") async def downstream_app( scope: Scope, receive: Receive, send: Send, ) -> None: await send( { "type": "http.response.start", "status": 200, "headers": [ (b"content-type", b"text/plain"), ], } ) await send( { "type": "http.response.body", "body": b"first chunk, ", "more_body": True, } ) await send( { "type": "http.response.body", "body": b"second chunk", "more_body": True, } ) message = await receive() assert message["type"] == "http.disconnect" app = DiscardingMiddleware(downstream_app) client = test_client_factory(app) response = client.get("/does_not_exist") assert response.text == "Custom" def test_read_request_stream_in_app_after_middleware_calls_stream( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: expected = [b""] async for chunk in request.stream(): assert chunk == expected.pop(0) assert expected == [] return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: expected = [b"a", b""] async for chunk in request.stream(): assert chunk == expected.pop(0) assert expected == [] return await call_next(request) app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = test_client_factory(app) response = client.post("/", content=b"a") assert response.status_code == 200 def test_read_request_stream_in_app_after_middleware_calls_body( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: expected = [b"a", b""] async for chunk in 
request.stream(): assert chunk == expected.pop(0) assert expected == [] return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: assert await request.body() == b"a" return await call_next(request) app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = test_client_factory(app) response = client.post("/", content=b"a") assert response.status_code == 200 def test_read_request_body_in_app_after_middleware_calls_stream( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: expected = [b"a", b""] async for chunk in request.stream(): assert chunk == expected.pop(0) assert expected == [] return await call_next(request) app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = test_client_factory(app) response = client.post("/", content=b"a") assert response.status_code == 200 def test_read_request_body_in_app_after_middleware_calls_body( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: assert await request.body() == b"a" return await call_next(request) app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = test_client_factory(app) response = client.post("/", content=b"a") assert 
response.status_code == 200 def test_read_request_stream_in_dispatch_after_app_calls_stream( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: expected = [b"a", b""] async for chunk in request.stream(): assert chunk == expected.pop(0) assert expected == [] return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: resp = await call_next(request) with pytest.raises(RuntimeError, match="Stream consumed"): async for _ in request.stream(): raise AssertionError("should not be called") # pragma: no cover return resp app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = test_client_factory(app) response = client.post("/", content=b"a") assert response.status_code == 200 def test_read_request_stream_in_dispatch_after_app_calls_body( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: resp = await call_next(request) with pytest.raises(RuntimeError, match="Stream consumed"): async for _ in request.stream(): raise AssertionError("should not be called") # pragma: no cover return resp app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = test_client_factory(app) response = client.post("/", content=b"a") assert response.status_code == 200 @pytest.mark.anyio async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None: async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) async for chunk in 
request.stream(): # pragma: no branch assert chunk == b"2" break await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: expected = b"1" response: Response | None = None async for chunk in request.stream(): # pragma: no branch assert chunk == expected if expected == b"1": response = await call_next(request) expected = b"3" else: break assert response is not None return response async def rcv() -> AsyncGenerator[Message, None]: yield {"type": "http.request", "body": b"1", "more_body": True} yield {"type": "http.request", "body": b"2", "more_body": True} yield {"type": "http.request", "body": b"3"} raise AssertionError( # pragma: no cover "Should not be called, no need to poll for disconnect" ) sent: list[Message] = [] async def send(msg: Message) -> None: sent.append(msg) app: ASGIApp = endpoint app = ConsumingMiddleware(app) rcv_stream = rcv() await app({"type": "http"}, rcv_stream.__anext__, send) assert sent == [ { "type": "http.response.start", "status": 200, "headers": [(b"content-length", b"0")], }, {"type": "http.response.body", "body": b"", "more_body": False}, ] await rcv_stream.aclose() def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: assert await request.body() == b"a" # this buffers the request body in memory resp = await call_next(request) async for chunk in request.stream(): if chunk: assert chunk == b"a" return resp app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = 
test_client_factory(app) response = client.post("/", content=b"a") assert response.status_code == 200 def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( test_client_factory: TestClientFactory, ) -> None: async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: assert await request.body() == b"a" # this buffers the request body in memory resp = await call_next(request) assert await request.body() == b"a" # no problem here return resp app = Starlette( routes=[Route("/", homepage, methods=["POST"])], middleware=[Middleware(ConsumingMiddleware)], ) client: TestClient = test_client_factory(app) response = client.post("/", content=b"a") assert response.status_code == 200 @pytest.mark.anyio async def test_read_request_disconnected_client() -> None: """If we receive a disconnect message when the downstream ASGI app calls receive() the Request instance passed into the dispatch function should get marked as disconnected. The downstream ASGI app should not get a ClientDisconnect raised, instead if should just receive the disconnect message. 
""" async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: msg = await receive() assert msg["type"] == "http.disconnect" await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: response = await call_next(request) disconnected = await request.is_disconnected() assert disconnected is True return response scope = {"type": "http", "method": "POST", "path": "/"} async def receive() -> AsyncGenerator[Message, None]: yield {"type": "http.disconnect"} raise AssertionError("Should not be called, would hang") # pragma: no cover async def send(msg: Message) -> None: if msg["type"] == "http.response.start": assert msg["status"] == 200 app: ASGIApp = ConsumingMiddleware(endpoint) rcv = receive() await app(scope, rcv.__anext__, send) await rcv.aclose() @pytest.mark.anyio async def test_read_request_disconnected_after_consuming_steam() -> None: async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: msg = await receive() assert msg.pop("more_body", False) is False assert msg == {"type": "http.request", "body": b"hi"} msg = await receive() assert msg == {"type": "http.disconnect"} await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: await request.body() disconnected = await request.is_disconnected() assert disconnected is True response = await call_next(request) return response scope = {"type": "http", "method": "POST", "path": "/"} async def receive() -> AsyncGenerator[Message, None]: yield {"type": "http.request", "body": b"hi"} yield {"type": "http.disconnect"} raise AssertionError("Should not be called, would hang") # pragma: no cover async def send(msg: Message) -> None: if msg["type"] == "http.response.start": assert msg["status"] == 200 app: ASGIApp = ConsumingMiddleware(endpoint) rcv = 
receive() await app(scope, rcv.__anext__, send) await rcv.aclose() def test_downstream_middleware_modifies_receive( test_client_factory: TestClientFactory, ) -> None: """If a downstream middleware modifies receive() the final ASGI app should see the modified version. """ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = await request.body() assert body == b"foo foo " await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: body = await request.body() assert body == b"foo " return await call_next(request) def modifying_middleware(app: ASGIApp) -> ASGIApp: async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: async def wrapped_receive() -> Message: msg = await receive() if msg["type"] == "http.request": # pragma: no branch msg["body"] = msg["body"] * 2 return msg await app(scope, wrapped_receive, send) return wrapped_app client = test_client_factory(ConsumingMiddleware(modifying_middleware(endpoint))) resp = client.post("/", content=b"foo ") assert resp.status_code == 200 def test_pr_1519_comment_1236166180_example() -> None: """ https://github.com/Kludex/starlette/pull/1519#issuecomment-1236166180 """ bodies: list[bytes] = [] class LogRequestBodySize(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: print(len(await request.body())) return await call_next(request) def replace_body_middleware(app: ASGIApp) -> ASGIApp: async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: async def wrapped_rcv() -> Message: msg = await receive() msg["body"] += b"-foo" return msg await app(scope, wrapped_rcv, send) return wrapped_app async def endpoint(request: Request) -> Response: body = await request.body() bodies.append(body) return Response() app: ASGIApp = Starlette(routes=[Route("/", 
endpoint, methods=["POST"])]) app = replace_body_middleware(app) app = LogRequestBodySize(app) client = TestClient(app) resp = client.post("/", content=b"Hello, World!") resp.raise_for_status() assert bodies == [b"Hello, World!-foo"] @pytest.mark.anyio async def test_multiple_middlewares_stacked_client_disconnected() -> None: """ Tests for: - https://github.com/Kludex/starlette/issues/2516 - https://github.com/Kludex/starlette/pull/2687 """ ordered_events: list[str] = [] unordered_events: list[str] = [] class MyMiddleware(BaseHTTPMiddleware): def __init__(self, app: ASGIApp, version: int) -> None: self.version = version super().__init__(app) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: ordered_events.append(f"{self.version}:STARTED") res = await call_next(request) ordered_events.append(f"{self.version}:COMPLETED") def background() -> None: unordered_events.append(f"{self.version}:BACKGROUND") assert res.background is None res.background = BackgroundTask(background) return res async def sleepy(request: Request) -> Response: try: await request.body() except ClientDisconnect: pass else: # pragma: no cover raise AssertionError("Should have raised ClientDisconnect") return Response(b"") app = Starlette( routes=[Route("/", sleepy)], middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)], ) scope = { "type": "http", "version": "3", "method": "GET", "path": "/", } async def receive() -> AsyncIterator[Message]: yield {"type": "http.disconnect"} sent: list[Message] = [] async def send(message: Message) -> None: sent.append(message) await app(scope, receive().__anext__, send) assert ordered_events == [ "1:STARTED", "2:STARTED", "3:STARTED", "4:STARTED", "5:STARTED", "6:STARTED", "7:STARTED", "8:STARTED", "9:STARTED", "10:STARTED", "10:COMPLETED", "9:COMPLETED", "8:COMPLETED", "7:COMPLETED", "6:COMPLETED", "5:COMPLETED", "4:COMPLETED", "3:COMPLETED", "2:COMPLETED", "1:COMPLETED", ] assert sorted(unordered_events) 
== sorted( [ "1:BACKGROUND", "2:BACKGROUND", "3:BACKGROUND", "4:BACKGROUND", "5:BACKGROUND", "6:BACKGROUND", "7:BACKGROUND", "8:BACKGROUND", "9:BACKGROUND", "10:BACKGROUND", ] ) assert sent == [ { "type": "http.response.start", "status": 200, "headers": [(b"content-length", b"0")], }, {"type": "http.response.body", "body": b"", "more_body": False}, ] @pytest.mark.anyio @pytest.mark.parametrize("send_body", [True, False]) async def test_poll_for_disconnect_repeated(send_body: bool) -> None: async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None: for _ in range(2): msg = await receive() while msg["type"] == "http.request": msg = await receive() assert msg["type"] == "http.disconnect" await Response(b"good!")(scope, receive, send) class MyMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(request) app = MyMiddleware(app_poll_disconnect) scope = { "type": "http", "version": "3", "method": "GET", "path": "/", } async def receive() -> AsyncIterator[Message]: # the key here is that we only ever send 1 htt.disconnect message if send_body: yield {"type": "http.request", "body": b"hello", "more_body": True} yield {"type": "http.request", "body": b"", "more_body": False} yield {"type": "http.disconnect"} raise AssertionError("Should not be called, would hang") # pragma: no cover sent: list[Message] = [] async def send(message: Message) -> None: sent.append(message) await app(scope, receive().__anext__, send) assert sent == [ { "type": "http.response.start", "status": 200, "headers": [(b"content-length", b"5")], }, {"type": "http.response.body", "body": b"good!", "more_body": True}, {"type": "http.response.body", "body": b"", "more_body": False}, ] @pytest.mark.anyio async def test_asgi_pathsend_events(tmpdir: Path) -> None: path = tmpdir / "example.txt" with path.open("w") as file: file.write("") response_complete = anyio.Event() events: list[Message] = 
[] async def endpoint_with_pathsend(_: Request) -> FileResponse: return FileResponse(path) async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: return await call_next(request) app = Starlette( middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], routes=[Route("/", endpoint_with_pathsend)], ) scope = { "type": "http", "version": "3", "method": "GET", "path": "/", "headers": [], "extensions": {"http.response.pathsend": {}}, } async def receive() -> Message: raise NotImplementedError("Should not be called!") # pragma: no cover async def send(message: Message) -> None: events.append(message) if message["type"] == "http.response.pathsend": response_complete.set() await app(scope, receive, send) assert len(events) == 2 assert events[0]["type"] == "http.response.start" assert events[1]["type"] == "http.response.pathsend" def test_error_context_propagation(test_client_factory: TestClientFactory) -> None: class PassthroughMiddleware(BaseHTTPMiddleware): async def dispatch( self, request: Request, call_next: RequestResponseEndpoint, ) -> Response: return await call_next(request) def exception_without_context(request: Request) -> None: raise Exception("Exception") def exception_with_context(request: Request) -> None: try: raise Exception("Inner exception") except Exception: raise Exception("Outer exception") def exception_with_cause(request: Request) -> None: try: raise Exception("Inner exception") except Exception as e: raise Exception("Outer exception") from e app = Starlette( routes=[ Route("/exception-without-context", endpoint=exception_without_context), Route("/exception-with-context", endpoint=exception_with_context), Route("/exception-with-cause", endpoint=exception_with_cause), ], middleware=[Middleware(PassthroughMiddleware)], ) client = test_client_factory(app) # For exceptions without context the context is filled with the `anyio.EndOfStream` # but it is suppressed therefore not propagated to traceback. 
with pytest.raises(Exception) as ctx: client.get("/exception-without-context") assert str(ctx.value) == "Exception" assert ctx.value.__cause__ is None assert ctx.value.__context__ is not None assert ctx.value.__suppress_context__ is True # For exceptions with context the context is propagated as a cause to avoid # `anyio.EndOfStream` error from overwriting it. with pytest.raises(Exception) as ctx: client.get("/exception-with-context") assert str(ctx.value) == "Outer exception" assert ctx.value.__cause__ is not None assert str(ctx.value.__cause__) == "Inner exception" # For exceptions with cause check that it gets correctly propagated. with pytest.raises(Exception) as ctx: client.get("/exception-with-cause") assert str(ctx.value) == "Outer exception" assert ctx.value.__cause__ is not None assert str(ctx.value.__cause__) == "Inner exception" ================================================ FILE: tests/middleware/test_cors.py ================================================ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route from tests.types import TestClientFactory def test_cors_allow_all( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"], expose_headers=["X-Status"], allow_credentials=True, ) ], ) client = test_client_factory(app) # Test pre-flight response headers = { "Origin": "https://example.org", "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Example", } response = client.options("/", headers=headers) assert response.status_code == 200 assert response.text 
== "OK" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-headers"] == "X-Example" assert response.headers["access-control-allow-credentials"] == "true" assert response.headers["vary"] == "Origin" # Test standard response headers = {"Origin": "https://example.org"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-expose-headers"] == "X-Status" assert response.headers["access-control-allow-credentials"] == "true" # Test standard credentialed response headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-expose-headers"] == "X-Status" assert response.headers["access-control-allow-credentials"] == "true" # Test non-CORS response response = client.get("/") assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-origin" not in response.headers def test_cors_allow_all_except_credentials( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"], expose_headers=["X-Status"], ) ], ) client = test_client_factory(app) # Test pre-flight response headers = { "Origin": "https://example.org", "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Example", } response = client.options("/", headers=headers) assert response.status_code == 200 assert 
response.text == "OK" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-allow-headers"] == "X-Example" assert "access-control-allow-credentials" not in response.headers assert "vary" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-expose-headers"] == "X-Status" assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-origin" not in response.headers def test_cors_allow_specific_origin( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_origins=["https://example.org"], allow_headers=["X-Example", "Content-Type"], ) ], ) client = test_client_factory(app) # Test pre-flight response headers = { "Origin": "https://example.org", "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Example, Content-Type", } response = client.options("/", headers=headers) assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-headers"] == ( "Accept, Accept-Language, Content-Language, Content-Type, X-Example" ) assert "access-control-allow-credentials" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert 
response.headers["access-control-allow-origin"] == "https://example.org" assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-origin" not in response.headers def test_cors_disallowed_preflight( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> None: pass # pragma: no cover app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_origins=["https://example.org"], allow_headers=["X-Example"], ) ], ) client = test_client_factory(app) # Test pre-flight response headers = { "Origin": "https://another.org", "Access-Control-Request-Method": "POST", "Access-Control-Request-Headers": "X-Nope", } response = client.options("/", headers=headers) assert response.status_code == 400 assert response.text == "Disallowed CORS origin, method, headers" assert "access-control-allow-origin" not in response.headers # Bug specific test, https://github.com/Kludex/starlette/pull/1199 # Test preflight response text with multiple disallowed headers headers = { "Origin": "https://example.org", "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Nope-1, X-Nope-2", } response = client.options("/", headers=headers) assert response.text == "Disallowed CORS headers" def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> None: return # pragma: no cover app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["POST"], allow_credentials=True, ) ], ) client = test_client_factory(app) # Test pre-flight response headers = { "Origin": "https://example.org", "Access-Control-Request-Method": "POST", } response = client.options( "/", headers=headers, ) 
assert response.status_code == 200 assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-credentials"] == "true" assert response.headers["vary"] == "Origin" def test_cors_preflight_allow_all_methods( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> None: pass # pragma: no cover app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])], ) client = test_client_factory(app) headers = { "Origin": "https://example.org", "Access-Control-Request-Method": "POST", } for method in ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"): response = client.options("/", headers=headers) assert response.status_code == 200 assert method in response.headers["access-control-allow-methods"] def test_cors_allow_all_methods( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[ Route( "/", endpoint=homepage, methods=["delete", "get", "head", "options", "patch", "post", "put"], ) ], middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])], ) client = test_client_factory(app) headers = {"Origin": "https://example.org"} for method in ("patch", "post", "put"): response = getattr(client, method)("/", headers=headers, json={}) assert response.status_code == 200 for method in ("delete", "get", "head", "options"): response = getattr(client, method)("/", headers=headers) assert response.status_code == 200 def test_cors_allow_origin_regex( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_headers=["X-Example", "Content-Type"], 
allow_origin_regex="https://.*", allow_credentials=True, ) ], ) client = test_client_factory(app) # Test standard response headers = {"Origin": "https://example.org"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-credentials"] == "true" # Test standard credentialed response headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-credentials"] == "true" # Test disallowed standard response # Note that enforcement is a browser concern. The disallowed-ness is reflected # in the lack of an "access-control-allow-origin" header in the response. headers = {"Origin": "http://example.org"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-origin" not in response.headers # Test pre-flight response headers = { "Origin": "https://another.com", "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Example, content-type", } response = client.options("/", headers=headers) assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "https://another.com" assert response.headers["access-control-allow-headers"] == ( "Accept, Accept-Language, Content-Language, Content-Type, X-Example" ) assert response.headers["access-control-allow-credentials"] == "true" # Test disallowed pre-flight response headers = { "Origin": "http://another.com", "Access-Control-Request-Method": "GET", "Access-Control-Request-Headers": "X-Example", } response = client.options("/", 
headers=headers) assert response.status_code == 400 assert response.text == "Disallowed CORS origin" assert "access-control-allow-origin" not in response.headers def test_cors_allow_origin_regex_fullmatch( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_headers=["X-Example", "Content-Type"], allow_origin_regex=r"https://.*\.example.org", ) ], ) client = test_client_factory(app) # Test standard response headers = {"Origin": "https://subdomain.example.org"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://subdomain.example.org" assert "access-control-allow-credentials" not in response.headers # Test disallowed standard response headers = {"Origin": "https://subdomain.example.org.hacker.com"} response = client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-origin" not in response.headers def test_cors_vary_header_defaults_to_origin(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], ) headers = {"Origin": "https://example.org"} client = test_client_factory(app) response = client.get("/", headers=headers) assert response.status_code == 200 assert response.headers["vary"] == "Origin" def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200, headers={"Vary": 
"Accept-Encoding"}) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["*"])], ) client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding" def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True)], ) client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}) app = Starlette( routes=[ Route("/", endpoint=homepage), ], middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], ) client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://example.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" def test_cors_allowed_origin_does_not_leak_between_requests(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], ) client = test_client_factory(app) response = client.get("/", headers={"Origin": 
"https://example.org"}) assert response.headers["access-control-allow-origin"] == "https://example.org" response = client.get("/", headers={"Origin": "https://other.org"}) assert "access-control-allow-origin" not in response.headers response = client.get("/", headers={"Origin": "https://example.org"}) assert response.headers["access-control-allow-origin"] == "https://example.org" def test_cors_private_network_access_allowed(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_private_network=True, ) ], ) client = test_client_factory(app) headers_without_pna = {"Origin": "https://example.org", "Access-Control-Request-Method": "GET"} headers_with_pna = {**headers_without_pna, "Access-Control-Request-Private-Network": "true"} # Test preflight with Private Network Access request response = client.options("/", headers=headers_with_pna) assert response.status_code == 200 assert response.text == "OK" assert response.headers["access-control-allow-private-network"] == "true" # Test preflight without Private Network Access request response = client.options("/", headers=headers_without_pna) assert response.status_code == 200 assert response.text == "OK" assert "access-control-allow-private-network" not in response.headers # The access-control-allow-private-network header is not set for non-preflight requests response = client.get("/", headers=headers_with_pna) assert response.status_code == 200 assert response.text == "Homepage" assert "access-control-allow-private-network" not in response.headers assert "access-control-allow-origin" in response.headers def test_cors_private_network_access_disallowed(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> None: ... 
# pragma: no cover app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_private_network=False, ) ], ) client = test_client_factory(app) # Test preflight with Private Network Access request when not allowed headers_without_pna = {"Origin": "https://example.org", "Access-Control-Request-Method": "GET"} headers_with_pna = {**headers_without_pna, "Access-Control-Request-Private-Network": "true"} response = client.options("/", headers=headers_without_pna) assert response.status_code == 200 assert response.text == "OK" assert "access-control-allow-private-network" not in response.headers # If the request includes a Private Network Access header, but the middleware is configured to disallow it, the # request should be denied with a 400 response. response = client.options("/", headers=headers_with_pna) assert response.status_code == 400 assert response.text == "Disallowed CORS private-network" assert "access-control-allow-private-network" not in response.headers ================================================ FILE: tests/middleware/test_errors.py ================================================ from typing import Any import pytest from starlette.applications import Starlette from starlette.background import BackgroundTask from starlette.middleware.errors import ServerErrorMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route from starlette.types import Receive, Scope, Send from tests.types import TestClientFactory def test_handler( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") def error_500(request: Request, exc: Exception) -> JSONResponse: return JSONResponse({"detail": "Server Error"}, status_code=500) app = ServerErrorMiddleware(app, handler=error_500) client = 
test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} def test_debug_text(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.headers["content-type"].startswith("text/plain") assert "RuntimeError: Something went wrong" in response.text def test_debug_html(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 assert response.headers["content-type"].startswith("text/html") assert "RuntimeError" in response.text def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"", status_code=204) await response(scope, receive, send) raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) client = test_client_factory(app) with pytest.raises(RuntimeError): client.get("/") def test_debug_not_http(test_client_factory: TestClientFactory) -> None: """ DebugMiddleware should just pass through any non-http messages as-is. 
""" async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app) with pytest.raises(RuntimeError): client = test_client_factory(app) with client.websocket_connect("/"): pass # pragma: no cover def test_background_task(test_client_factory: TestClientFactory) -> None: accessed_error_handler = False def error_handler(request: Request, exc: Exception) -> Any: nonlocal accessed_error_handler accessed_error_handler = True def raise_exception() -> None: raise Exception("Something went wrong") async def endpoint(request: Request) -> Response: task = BackgroundTask(raise_exception) return Response(status_code=204, background=task) app = Starlette( routes=[Route("/", endpoint=endpoint)], exception_handlers={Exception: error_handler}, ) client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 204 assert accessed_error_handler ================================================ FILE: tests/middleware/test_gzip.py ================================================ from __future__ import annotations from pathlib import Path import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request from starlette.responses import ContentStream, FileResponse, PlainTextResponse, StreamingResponse from starlette.routing import Route from starlette.types import Message from tests.types import TestClientFactory def test_gzip_responses(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("x" * 4000, status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert 
response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "gzip" assert response.headers["Vary"] == "Accept-Encoding" assert int(response.headers["Content-Length"]) < 4000 def test_gzip_not_in_accept_encoding(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("x" * 4000, status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "identity"}) assert response.status_code == 200 assert response.text == "x" * 4000 assert "Content-Encoding" not in response.headers assert response.headers["Vary"] == "Accept-Encoding" assert int(response.headers["Content-Length"]) == 4000 def test_gzip_ignored_for_small_responses( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "OK" assert "Content-Encoding" not in response.headers assert "Vary" not in response.headers assert int(response.headers["Content-Length"]) == 2 def test_gzip_streaming_response(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for index in range(count): yield bytes streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert 
response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "gzip" assert response.headers["Vary"] == "Accept-Encoding" assert "Content-Length" not in response.headers def test_gzip_streaming_response_identity(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for index in range(count): yield bytes streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "identity"}) assert response.status_code == 200 assert response.text == "x" * 4000 assert "Content-Encoding" not in response.headers assert response.headers["Vary"] == "Accept-Encoding" assert "Content-Length" not in response.headers def test_gzip_ignored_for_responses_with_encoding_set( test_client_factory: TestClientFactory, ) -> None: def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for index in range(count): yield bytes streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"}) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip, text"}) assert response.status_code == 200 assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "text" assert "Vary" not in response.headers assert "Content-Length" not in response.headers def test_gzip_ignored_on_server_sent_events(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for _ in 
range(count): yield bytes streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200, media_type="text/event-stream") app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)], ) client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 assert "Content-Encoding" not in response.headers assert "Content-Length" not in response.headers @pytest.mark.anyio async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None: path = tmpdir / "example.txt" with path.open("w") as file: file.write("") events: list[Message] = [] async def endpoint_with_pathsend(request: Request) -> FileResponse: _ = await request.body() return FileResponse(path) app = Starlette( routes=[Route("/", endpoint=endpoint_with_pathsend)], middleware=[Middleware(GZipMiddleware)], ) scope = { "type": "http", "version": "3", "method": "GET", "path": "/", "headers": [(b"accept-encoding", b"gzip, text")], "extensions": {"http.response.pathsend": {}}, } async def receive() -> Message: return {"type": "http.request", "body": b"", "more_body": False} async def send(message: Message) -> None: events.append(message) await app(scope, receive, send) assert len(events) == 2 assert events[0]["type"] == "http.response.start" assert events[1]["type"] == "http.response.pathsend" ================================================ FILE: tests/middleware/test_https_redirect.py ================================================ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route from tests.types import TestClientFactory def test_https_redirect_middleware(test_client_factory: TestClientFactory) -> None: 
def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(HTTPSRedirectMiddleware)], ) client = test_client_factory(app, base_url="https://testserver") response = client.get("/") assert response.status_code == 200 client = test_client_factory(app) response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:80") response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:443") response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:123") response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver:123/" ================================================ FILE: tests/middleware/test_middleware.py ================================================ from starlette.middleware import Middleware from starlette.types import ASGIApp, Receive, Scope, Send class CustomMiddleware: # pragma: no cover def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None: self.app = app self.foo = foo self.bar = bar async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def test_middleware_repr() -> None: middleware = Middleware(CustomMiddleware, "foo", bar=123) assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)" def test_middleware_iter() -> None: cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123) assert (cls, args, kwargs) == (CustomMiddleware, 
("foo",), {"bar": 123}) ================================================ FILE: tests/middleware/test_session.py ================================================ import re from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.sessions import Session, SessionMiddleware from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount, Route from starlette.testclient import TestClient from tests.types import TestClientFactory def view_session(request: Request) -> JSONResponse: return JSONResponse({"session": request.session}) async def update_session(request: Request) -> JSONResponse: data = await request.json() request.session.update(data) return JSONResponse({"session": request.session}) async def clear_session(request: Request) -> JSONResponse: request.session.clear() return JSONResponse({"session": request.session}) def no_session_access(request: Request) -> JSONResponse: return JSONResponse({"status": "ok"}) def test_session(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), Route("/clear_session", endpoint=clear_session, methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example")], ) client = test_client_factory(app) response = client.get("/view_session") assert response.json() == {"session": {}} response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} # check cookie max-age set_cookie = response.headers["set-cookie"] max_age_matches = re.search(r"; Max-Age=([0-9]+);", set_cookie) assert max_age_matches is not None assert int(max_age_matches[1]) == 14 * 24 * 3600 response = client.get("/view_session") assert response.json() == {"session": {"some": "data"}} response = client.post("/clear_session") assert response.json() == 
{"session": {}} response = client.get("/view_session") assert response.json() == {"session": {}} def test_session_expires(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=-1)], ) client = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} # requests removes expired cookies from response.cookies, we need to # fetch session id from the headers and pass it explicitly expired_cookie_header = response.headers["set-cookie"] expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header) assert expired_session_match is not None expired_session_value = expired_session_match[1] client = test_client_factory(app, cookies={"session": expired_session_value}) response = client.get("/view_session") assert response.json() == {"session": {}} def test_secure_session(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), Route("/clear_session", endpoint=clear_session, methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example", https_only=True)], ) secure_client = test_client_factory(app, base_url="https://testserver") unsecure_client = test_client_factory(app, base_url="http://testserver") response = unsecure_client.get("/view_session") assert response.json() == {"session": {}} response = unsecure_client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} response = unsecure_client.get("/view_session") assert response.json() == {"session": {}} response = secure_client.get("/view_session") assert response.json() == {"session": {}} response = 
secure_client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} response = secure_client.get("/view_session") assert response.json() == {"session": {"some": "data"}} response = secure_client.post("/clear_session") assert response.json() == {"session": {}} response = secure_client.get("/view_session") assert response.json() == {"session": {}} def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None: second_app = Starlette( routes=[ Route("/update_session", endpoint=update_session, methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example", path="/second_app")], ) app = Starlette(routes=[Mount("/second_app", app=second_app)]) client = test_client_factory(app, base_url="http://testserver/second_app") response = client.post("/update_session", json={"some": "data"}) assert response.status_code == 200 cookie = response.headers["set-cookie"] cookie_path_match = re.search(r"; path=(\S+);", cookie) assert cookie_path_match is not None cookie_path = cookie_path_match.groups()[0] assert cookie_path == "/second_app" def test_invalid_session_cookie(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example")], ) client = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} # we expect it to not raise an exception if we provide a bogus session cookie client = test_client_factory(app, cookies={"session": "invalid"}) response = client.get("/view_session") assert response.json() == {"session": {}} def test_session_cookie(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, 
methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)], ) client: TestClient = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} # check cookie max-age set_cookie = response.headers["set-cookie"] assert "Max-Age" not in set_cookie client.cookies.delete("session") response = client.get("/view_session") assert response.json() == {"session": {}} def test_domain_cookie(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example", domain=".example.com")], ) client: TestClient = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} # check cookie max-age set_cookie = response.headers["set-cookie"] assert "domain=.example.com" in set_cookie client.cookies.delete("session") response = client.get("/view_session") assert response.json() == {"session": {}} def test_set_cookie_only_on_modification(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), ], middleware=[Middleware(SessionMiddleware, secret_key="example")], ) client = test_client_factory(app) # Write to session - should send Set-Cookie response = client.post("/update_session", json={"some": "data"}) assert "set-cookie" in response.headers # Read-only access - should NOT send Set-Cookie response = client.get("/view_session") assert response.json() == {"session": {"some": "data"}} assert "set-cookie" not in response.headers def test_vary_cookie_on_access(test_client_factory: TestClientFactory) -> None: app = Starlette( routes=[ Route("/view_session", 
endpoint=view_session), Route("/update_session", endpoint=update_session, methods=["POST"]), Route("/no_session", endpoint=no_session_access), ], middleware=[Middleware(SessionMiddleware, secret_key="example")], ) client = test_client_factory(app) # Modifying session should add Vary: Cookie response = client.post("/update_session", json={"some": "data"}) assert "cookie" in response.headers.get("vary", "").lower() # Reading a non-empty session should add Vary: Cookie response = client.get("/view_session") assert "cookie" in response.headers.get("vary", "").lower() # Not accessing session at all should NOT add Vary: Cookie response = client.get("/no_session") assert "cookie" not in response.headers.get("vary", "").lower() def test_session_tracks_modification() -> None: session = Session({"a": "1", "b": "2"}) assert not session.modified # __setitem__ session["c"] = "3" assert session.modified # __delitem__ session = Session({"a": "1"}) del session["a"] assert session.modified # clear session = Session({"a": "1"}) session.clear() assert session.modified # pop with existing key session = Session({"a": "1"}) session.pop("a") assert session.modified # pop with missing key session = Session({"a": "1"}) session.pop("missing", None) assert not session.modified # setdefault with missing key session = Session({"a": "1"}) session.setdefault("b", "2") assert session.modified # setdefault with existing key session = Session({"a": "1"}) session.setdefault("a", "2") assert not session.modified ================================================ FILE: tests/middleware/test_trusted_host.py ================================================ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route from tests.types import TestClientFactory def 
test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])], ) client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 client = test_client_factory(app, base_url="http://subdomain.testserver") response = client.get("/") assert response.status_code == 200 client = test_client_factory(app, base_url="http://invalidhost") response = client.get("/") assert response.status_code == 400 def test_default_allowed_hosts() -> None: app = Starlette() middleware = TrustedHostMiddleware(app) assert middleware.allowed_hosts == ["*"] def test_www_redirect(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) app = Starlette( routes=[Route("/", endpoint=homepage)], middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])], ) client = test_client_factory(app, base_url="https://example.com") response = client.get("/") assert response.status_code == 200 assert response.url == "https://www.example.com/" ================================================ FILE: tests/middleware/test_wsgi.py ================================================ import sys from collections.abc import Callable, Iterable from typing import Any import pytest from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ from tests.types import TestClientFactory WSGIResponse = Iterable[bytes] StartResponse = Callable[..., Any] Environment = dict[str, Any] def hello_world( environ: Environment, start_response: StartResponse, ) -> WSGIResponse: status = "200 OK" output = b"Hello World!\n" headers = [ ("Content-Type", "text/plain; charset=utf-8"), 
("Content-Length", str(len(output))), ] start_response(status, headers) return [output] def echo_body( environ: Environment, start_response: StartResponse, ) -> WSGIResponse: status = "200 OK" output = environ["wsgi.input"].read() headers = [ ("Content-Type", "text/plain; charset=utf-8"), ("Content-Length", str(len(output))), ] start_response(status, headers) return [output] def raise_exception( environ: Environment, start_response: StartResponse, ) -> WSGIResponse: raise RuntimeError("Something went wrong") def return_exc_info( environ: Environment, start_response: StartResponse, ) -> WSGIResponse: try: raise RuntimeError("Something went wrong") except RuntimeError: status = "500 Internal Server Error" output = b"Internal Server Error" headers = [ ("Content-Type", "text/plain; charset=utf-8"), ("Content-Length", str(len(output))), ] start_response(status, headers, exc_info=sys.exc_info()) return [output] def test_wsgi_get(test_client_factory: TestClientFactory) -> None: app = WSGIMiddleware(hello_world) client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello World!\n" def test_wsgi_post(test_client_factory: TestClientFactory) -> None: app = WSGIMiddleware(echo_body) client = test_client_factory(app) response = client.post("/", json={"example": 123}) assert response.status_code == 200 assert response.text == '{"example":123}' def test_wsgi_exception(test_client_factory: TestClientFactory) -> None: # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) with pytest.raises(RuntimeError), collapse_excgroups(): client.get("/") def test_wsgi_exc_info(test_client_factory: TestClientFactory) -> None: # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. 
app = WSGIMiddleware(return_exc_info) client = test_client_factory(app) with pytest.raises(RuntimeError): response = client.get("/") app = WSGIMiddleware(return_exc_info) client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.text == "Internal Server Error" def test_build_environ() -> None: scope = { "type": "http", "http_version": "1.1", "method": "GET", "scheme": "https", "path": "/sub/", "root_path": "/sub", "query_string": b"a=123&b=456", "headers": [ (b"host", b"www.example.org"), (b"content-type", b"application/json"), (b"content-length", b"18"), (b"accept", b"application/json"), (b"accept", b"text/plain"), ], "client": ("134.56.78.4", 1453), "server": ("www.example.org", 443), } body = b'{"example":"body"}' environ = build_environ(scope, body) stream = environ.pop("wsgi.input") assert stream.read() == b'{"example":"body"}' assert environ == { "CONTENT_LENGTH": "18", "CONTENT_TYPE": "application/json", "HTTP_ACCEPT": "application/json,text/plain", "HTTP_HOST": "www.example.org", "PATH_INFO": "/", "QUERY_STRING": "a=123&b=456", "REMOTE_ADDR": "134.56.78.4", "REQUEST_METHOD": "GET", "SCRIPT_NAME": "/sub", "SERVER_NAME": "www.example.org", "SERVER_PORT": 443, "SERVER_PROTOCOL": "HTTP/1.1", "wsgi.errors": sys.stdout, "wsgi.multiprocess": True, "wsgi.multithread": True, "wsgi.run_once": False, "wsgi.url_scheme": "https", "wsgi.version": (1, 0), } def test_build_environ_encoding() -> None: scope = { "type": "http", "http_version": "1.1", "method": "GET", "path": "/小星", "root_path": "/中国", "query_string": b"a=123&b=456", "headers": [], } environ = build_environ(scope, b"") assert environ["SCRIPT_NAME"] == "/中国".encode().decode("latin-1") assert environ["PATH_INFO"] == "/小星".encode().decode("latin-1") ================================================ FILE: tests/statics/example.txt ================================================ 123 ================================================ 
FILE: tests/test__utils.py ================================================ import functools from typing import Any from unittest.mock import create_autospec import pytest from starlette._utils import get_route_path, is_async_callable from starlette.types import Scope def test_async_func() -> None: async def async_func() -> None: ... # pragma: no cover def func() -> None: ... # pragma: no cover assert is_async_callable(async_func) assert not is_async_callable(func) def test_async_partial() -> None: async def async_func(a: Any, b: Any) -> None: ... # pragma: no cover def func(a: Any, b: Any) -> None: ... # pragma: no cover partial = functools.partial(async_func, 1) assert is_async_callable(partial) partial = functools.partial(func, 1) # type: ignore assert not is_async_callable(partial) def test_async_method() -> None: class Async: async def method(self) -> None: ... # pragma: no cover class Sync: def method(self) -> None: ... # pragma: no cover assert is_async_callable(Async().method) assert not is_async_callable(Sync().method) def test_async_object_call() -> None: class Async: async def __call__(self) -> None: ... # pragma: no cover class Sync: def __call__(self) -> None: ... # pragma: no cover assert is_async_callable(Async()) assert not is_async_callable(Sync()) def test_async_partial_object_call() -> None: class Async: async def __call__( self, a: Any, b: Any, ) -> None: ... # pragma: no cover class Sync: def __call__( self, a: Any, b: Any, ) -> None: ... # pragma: no cover partial = functools.partial(Async(), 1) assert is_async_callable(partial) partial = functools.partial(Sync(), 1) # type: ignore assert not is_async_callable(partial) def test_async_nested_partial() -> None: async def async_func( a: Any, b: Any, ) -> None: ... 
# pragma: no cover partial = functools.partial(async_func, b=2) nested_partial = functools.partial(partial, a=1) assert is_async_callable(nested_partial) def test_async_mocked_async_function() -> None: async def async_func() -> None: ... # pragma: no cover mock = create_autospec(async_func) assert is_async_callable(mock) @pytest.mark.parametrize( "scope, expected_result", [ ({"path": "/foo-123/bar", "root_path": "/foo"}, "/foo-123/bar"), ({"path": "/foo/bar", "root_path": "/foo"}, "/bar"), ({"path": "/foo", "root_path": "/foo"}, ""), ({"path": "/foo/bar", "root_path": "/bar"}, "/foo/bar"), ], ) def test_get_route_path(scope: Scope, expected_result: str) -> None: assert get_route_path(scope) == expected_result ================================================ FILE: tests/test_applications.py ================================================ from __future__ import annotations import os from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator from contextlib import asynccontextmanager from pathlib import Path from typing import TypedDict import anyio.from_thread import pytest from starlette import status from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles from starlette.testclient import TestClient, WebSocketDenialResponse from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket from tests.types import TestClientFactory async def error_500(request: Request, exc: HTTPException) -> JSONResponse: return JSONResponse({"detail": "Server Error"}, status_code=500) async def 
method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse: return JSONResponse({"detail": "Custom message"}, status_code=405) async def http_exception(request: Request, exc: HTTPException) -> JSONResponse: return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) def func_homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, world!") async def async_homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, world!") class Homepage(HTTPEndpoint): def get(self, request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, world!") def all_users_page(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, everyone!") def user_page(request: Request) -> PlainTextResponse: username = request.path_params["username"] return PlainTextResponse(f"Hello, {username}!") def custom_subdomain(request: Request) -> PlainTextResponse: return PlainTextResponse("Subdomain: " + request.path_params["subdomain"]) def runtime_error(request: Request) -> None: raise RuntimeError() async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() async def websocket_raise_websocket_exception(websocket: WebSocket) -> None: await websocket.accept() raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA) async def websocket_raise_http_exception(websocket: WebSocket) -> None: raise HTTPException(status_code=401, detail="Unauthorized") class CustomWSException(Exception): pass async def websocket_raise_custom(websocket: WebSocket) -> None: await websocket.accept() raise CustomWSException() async def websocket_state(websocket: WebSocket[CustomState]) -> None: await websocket.accept() await websocket.send_json({"count": websocket.state["count"]}) await websocket.close() def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None: anyio.from_thread.run(websocket.close, 
status.WS_1013_TRY_AGAIN_LATER) class CustomState(TypedDict): count: int @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[CustomState]: yield {"count": 1} async def state_count(request: Request[CustomState]) -> JSONResponse: return JSONResponse({"count": request.state["count"]}, status_code=200) users = Router( routes=[ Route("/", endpoint=all_users_page), Route("/{username}", endpoint=user_page), ] ) subdomain = Router( routes=[ Route("/", custom_subdomain), ] ) exception_handlers = { 500: error_500, 405: method_not_allowed, HTTPException: http_exception, CustomWSException: custom_ws_exception_handler, } middleware = [Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])] app = Starlette( routes=[ Route("/func", endpoint=func_homepage), Route("/async", endpoint=async_homepage), Route("/class", endpoint=Homepage), Route("/state", endpoint=state_count), Route("/500", endpoint=runtime_error), WebSocketRoute("/ws", endpoint=websocket_endpoint), WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception), WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception), WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom), WebSocketRoute("/ws-state", endpoint=websocket_state), Mount("/users", app=users), Host("{subdomain}.example.org", app=subdomain), ], exception_handlers=exception_handlers, # type: ignore middleware=middleware, lifespan=lifespan, ) @pytest.fixture def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client def test_url_path_for() -> None: assert app.url_path_for("func_homepage") == "/func" def test_func_route(client: TestClient) -> None: response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" 
response = client.head("/func") assert response.status_code == 200 assert response.text == "" def test_async_route(client: TestClient) -> None: response = client.get("/async") assert response.status_code == 200 assert response.text == "Hello, world!" def test_class_route(client: TestClient) -> None: response = client.get("/class") assert response.status_code == 200 assert response.text == "Hello, world!" def test_mounted_route(client: TestClient) -> None: response = client.get("/users/") assert response.status_code == 200 assert response.text == "Hello, everyone!" def test_mounted_route_path_params(client: TestClient) -> None: response = client.get("/users/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" def test_subdomain_route(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 assert response.text == "Subdomain: foo" def test_websocket_route(client: TestClient) -> None: with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" 
def test_400(client: TestClient) -> None: response = client.get("/404") assert response.status_code == 404 assert response.json() == {"detail": "Not Found"} def test_405(client: TestClient) -> None: response = client.post("/func") assert response.status_code == 405 assert response.json() == {"detail": "Custom message"} response = client.post("/class") assert response.status_code == 405 assert response.json() == {"detail": "Custom message"} def test_500(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/500") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} def test_request_state(client: TestClient) -> None: response = client.get("/state") assert response.status_code == 200 assert response.json() == {"count": 1} def test_websocket_raise_websocket_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-websocket") as session: response = session.receive() assert response == { "type": "websocket.close", "code": status.WS_1003_UNSUPPORTED_DATA, "reason": "", } def test_websocket_state(client: TestClient) -> None: with client.websocket_connect("/ws-state") as session: response = session.receive_json() assert response == {"count": 1} def test_websocket_raise_http_exception(client: TestClient) -> None: with pytest.raises(WebSocketDenialResponse) as exc: with client.websocket_connect("/ws-raise-http"): pass # pragma: no cover assert exc.value.status_code == 401 assert exc.value.content == b'{"detail":"Unauthorized"}' def test_websocket_raise_custom_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-custom") as session: response = session.receive() assert response == { "type": "websocket.close", "code": status.WS_1013_TRY_AGAIN_LATER, "reason": "", } def test_middleware(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app, base_url="http://incorrecthost") response = 
client.get("/func") assert response.status_code == 400 assert response.text == "Invalid host header" def test_routes() -> None: assert app.routes == [ Route("/func", endpoint=func_homepage, methods=["GET"]), Route("/async", endpoint=async_homepage, methods=["GET"]), Route("/class", endpoint=Homepage), Route("/state", endpoint=state_count, methods=["GET"]), Route("/500", endpoint=runtime_error, methods=["GET"]), WebSocketRoute("/ws", endpoint=websocket_endpoint), WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception), WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception), WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom), WebSocketRoute("/ws-state", endpoint=websocket_state), Mount( "/users", app=Router( routes=[ Route("/", endpoint=all_users_page), Route("/{username}", endpoint=user_page), ] ), ), Host( "{subdomain}.example.org", app=Router(routes=[Route("/", endpoint=custom_subdomain)]), ), ] def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = Starlette( routes=[ Mount("/static", StaticFiles(directory=tmpdir)), ] ) client = test_client_factory(app) response = client.get("/static/example.txt") assert response.status_code == 200 assert response.text == "" response = client.post("/static/example.txt") assert response.status_code == 405 assert response.text == "Method Not Allowed" def test_app_debug(test_client_factory: TestClientFactory) -> None: async def homepage(request: Request) -> None: raise RuntimeError() app = Starlette( routes=[ Route("/", homepage), ], ) app.debug = True client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert "RuntimeError" in response.text assert app.debug def test_app_add_route(test_client_factory: TestClientFactory) -> None: async def homepage(request: Request) -> 
PlainTextResponse: return PlainTextResponse("Hello, World!") app = Starlette( routes=[ Route("/", endpoint=homepage), ] ) client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None: async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() app = Starlette( routes=[ WebSocketRoute("/ws", endpoint=websocket_endpoint), ] ) client = test_client_factory(app) with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None: startup_complete = False cleanup_complete = False @asynccontextmanager async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: nonlocal startup_complete, cleanup_complete startup_complete = True yield cleanup_complete = True app = Starlette(lifespan=lifespan) assert not startup_complete assert not cleanup_complete with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete assert cleanup_complete deprecated_lifespan = pytest.mark.filterwarnings( r"ignore" r":(async )?generator function lifespans are deprecated, use an " r"@contextlib\.asynccontextmanager function instead" r":DeprecationWarning" r":starlette.routing" ) @deprecated_lifespan def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None: startup_complete = False cleanup_complete = False async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: nonlocal startup_complete, cleanup_complete startup_complete = True yield cleanup_complete = True app = Starlette(lifespan=lifespan) # type: ignore assert not startup_complete assert not cleanup_complete with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete 
assert cleanup_complete @deprecated_lifespan def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None: startup_complete = False cleanup_complete = False def lifespan(app: ASGIApp) -> Generator[None, None, None]: nonlocal startup_complete, cleanup_complete startup_complete = True yield cleanup_complete = True app = Starlette(lifespan=lifespan) # type: ignore assert not startup_complete assert not cleanup_complete with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete assert cleanup_complete def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None: class NoOpMiddleware: def __init__(self, app: ASGIApp): self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) class SimpleInitializableMiddleware: counter = 0 def __init__(self, app: ASGIApp): self.app = app SimpleInitializableMiddleware.counter += 1 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) def get_app() -> ASGIApp: app = Starlette() app.add_middleware(SimpleInitializableMiddleware) app.add_middleware(NoOpMiddleware) return app app = get_app() with test_client_factory(app): pass assert SimpleInitializableMiddleware.counter == 1 test_client_factory(app).get("/foo") assert SimpleInitializableMiddleware.counter == 1 app = get_app() test_client_factory(app).get("/foo") assert SimpleInitializableMiddleware.counter == 2 def test_middleware_args(test_client_factory: TestClientFactory) -> None: calls: list[str] = [] class MiddlewareWithArgs: def __init__(self, app: ASGIApp, arg: str) -> None: self.app = app self.arg = arg async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: calls.append(self.arg) await self.app(scope, receive, send) app = Starlette() app.add_middleware(MiddlewareWithArgs, "foo") app.add_middleware(MiddlewareWithArgs, "bar") with 
test_client_factory(app): pass assert calls == ["bar", "foo"] def test_middleware_factory(test_client_factory: TestClientFactory) -> None: calls: list[str] = [] def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp: async def _app(scope: Scope, receive: Receive, send: Send) -> None: calls.append(arg) await app(scope, receive, send) return _app def get_middleware_factory() -> Callable[[ASGIApp, str], ASGIApp]: return _middleware_factory app = Starlette() app.add_middleware(_middleware_factory, arg="foo") app.add_middleware(get_middleware_factory(), "bar") with test_client_factory(app): pass assert calls == ["bar", "foo"] def test_lifespan_app_subclass() -> None: # This test exists to make sure that subclasses of Starlette # (like FastAPI) are compatible with the types hints for Lifespan class App(Starlette): pass @asynccontextmanager async def lifespan(app: App) -> AsyncIterator[None]: # pragma: no cover yield App(lifespan=lifespan) ================================================ FILE: tests/test_authentication.py ================================================ from __future__ import annotations import base64 import binascii from collections.abc import Awaitable, Callable from typing import Any from urllib.parse import urlencode import pytest from starlette.applications import Starlette from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, SimpleUser, requires from starlette.endpoints import HTTPEndpoint from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import HTTPConnection, Request from starlette.responses import JSONResponse, Response from starlette.routing import Route, WebSocketRoute from starlette.websockets import WebSocket, WebSocketDisconnect from tests.types import TestClientFactory AsyncEndpoint = Callable[..., Awaitable[Response]] SyncEndpoint = Callable[..., Response] class BasicAuth(AuthenticationBackend): async 
def authenticate( self, request: HTTPConnection, ) -> tuple[AuthCredentials, SimpleUser] | None: if "Authorization" not in request.headers: return None auth = request.headers["Authorization"] try: scheme, credentials = auth.split() decoded = base64.b64decode(credentials).decode("ascii") except (ValueError, UnicodeDecodeError, binascii.Error): raise AuthenticationError("Invalid basic auth credentials") username, _, password = decoded.partition(":") return AuthCredentials(["authenticated"]), SimpleUser(username) def homepage(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, } ) @requires("authenticated") async def dashboard(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, } ) @requires("authenticated", redirect="homepage") async def admin(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, } ) @requires("authenticated") def dashboard_sync(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, } ) class Dashboard(HTTPEndpoint): @requires("authenticated") def get(self, request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, } ) @requires("authenticated", redirect="homepage") def admin_sync(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, } ) @requires("authenticated") async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_json( { "authenticated": websocket.user.is_authenticated, "user": websocket.user.display_name, } ) def async_inject_decorator( **kwargs: Any, ) -> Callable[[AsyncEndpoint], 
Callable[..., Awaitable[Response]]]: def wrapper(endpoint: AsyncEndpoint) -> Callable[..., Awaitable[Response]]: async def app(request: Request) -> Response: return await endpoint(request=request, **kwargs) return app return wrapper @async_inject_decorator(additional="payload") @requires("authenticated") async def decorated_async(request: Request, additional: str) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, "additional": additional, } ) def sync_inject_decorator( **kwargs: Any, ) -> Callable[[SyncEndpoint], Callable[..., Response]]: def wrapper(endpoint: SyncEndpoint) -> Callable[..., Response]: def app(request: Request) -> Response: return endpoint(request=request, **kwargs) return app return wrapper @sync_inject_decorator(additional="payload") @requires("authenticated") def decorated_sync(request: Request, additional: str) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, "additional": additional, } ) def ws_inject_decorator(**kwargs: Any) -> Callable[..., AsyncEndpoint]: def wrapper(endpoint: AsyncEndpoint) -> AsyncEndpoint: def app(websocket: WebSocket) -> Awaitable[Response]: return endpoint(websocket=websocket, **kwargs) return app return wrapper @ws_inject_decorator(additional="payload") @requires("authenticated") async def websocket_endpoint_decorated(websocket: WebSocket, additional: str) -> None: await websocket.accept() await websocket.send_json( { "authenticated": websocket.user.is_authenticated, "user": websocket.user.display_name, "additional": additional, } ) app = Starlette( middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth())], routes=[ Route("/", endpoint=homepage), Route("/dashboard", endpoint=dashboard), Route("/admin", endpoint=admin), Route("/dashboard/sync", endpoint=dashboard_sync), Route("/dashboard/class", endpoint=Dashboard), Route("/admin/sync", endpoint=admin_sync), 
Route("/dashboard/decorated", endpoint=decorated_async), Route("/dashboard/decorated/sync", endpoint=decorated_sync), WebSocketRoute("/ws", endpoint=websocket_endpoint), WebSocketRoute("/ws/decorated", endpoint=websocket_endpoint_decorated), ], ) def test_invalid_decorator_usage() -> None: with pytest.raises(Exception): @requires("authenticated") def foo() -> None: # pragma: no cover pass def test_user_interface(test_client_factory: TestClientFactory) -> None: with test_client_factory(app) as client: response = client.get("/") assert response.status_code == 200 assert response.json() == {"authenticated": False, "user": ""} response = client.get("/", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} def test_authentication_required(test_client_factory: TestClientFactory) -> None: with test_client_factory(app) as client: response = client.get("/dashboard") assert response.status_code == 403 response = client.get("/dashboard", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} response = client.get("/dashboard/sync") assert response.status_code == 403 response = client.get("/dashboard/sync", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} response = client.get("/dashboard/class") assert response.status_code == 403 response = client.get("/dashboard/class", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} response = client.get("/dashboard/decorated", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == { "authenticated": True, "user": "tomchristie", "additional": "payload", } response = client.get("/dashboard/decorated") assert response.status_code == 403 response = 
client.get("/dashboard/decorated/sync", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == { "authenticated": True, "user": "tomchristie", "additional": "payload", } response = client.get("/dashboard/decorated/sync") assert response.status_code == 403 response = client.get("/dashboard", headers={"Authorization": "basic foobar"}) assert response.status_code == 400 assert response.text == "Invalid basic auth credentials" def test_websocket_authentication_required( test_client_factory: TestClientFactory, ) -> None: with test_client_factory(app) as client: with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws"): pass # pragma: no cover with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}): pass # pragma: no cover with client.websocket_connect("/ws", auth=("tomchristie", "example")) as websocket: data = websocket.receive_json() assert data == {"authenticated": True, "user": "tomchristie"} with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws/decorated"): pass # pragma: no cover with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws/decorated", headers={"Authorization": "basic foobar"}): pass # pragma: no cover with client.websocket_connect("/ws/decorated", auth=("tomchristie", "example")) as websocket: data = websocket.receive_json() assert data == { "authenticated": True, "user": "tomchristie", "additional": "payload", } def test_authentication_redirect(test_client_factory: TestClientFactory) -> None: with test_client_factory(app) as client: response = client.get("/admin") assert response.status_code == 200 url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"})) assert response.url == url response = client.get("/admin", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} 
response = client.get("/admin/sync") assert response.status_code == 200 url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin/sync"})) assert response.url == url response = client.get("/admin/sync", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse: return JSONResponse({"error": str(exc)}, status_code=401) @requires("authenticated") def control_panel(request: Request) -> JSONResponse: return JSONResponse( { "authenticated": request.user.is_authenticated, "user": request.user.display_name, } ) other_app = Starlette( routes=[Route("/control-panel", control_panel)], middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error)], ) def test_custom_on_error(test_client_factory: TestClientFactory) -> None: with test_client_factory(other_app) as client: response = client.get("/control-panel", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} response = client.get("/control-panel", headers={"Authorization": "basic foobar"}) assert response.status_code == 401 assert response.json() == {"error": "Invalid basic auth credentials"} ================================================ FILE: tests/test_background.py ================================================ import pytest from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response from starlette.types import Receive, Scope, Send from tests.types import TestClientFactory def test_async_task(test_client_factory: TestClientFactory) -> None: TASK_COMPLETE = False async def async_task() -> None: nonlocal TASK_COMPLETE TASK_COMPLETE = True task = BackgroundTask(async_task) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = 
Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE def test_sync_task(test_client_factory: TestClientFactory) -> None: TASK_COMPLETE = False def sync_task() -> None: nonlocal TASK_COMPLETE TASK_COMPLETE = True task = BackgroundTask(sync_task) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE def test_multiple_tasks(test_client_factory: TestClientFactory) -> None: TASK_COUNTER = 0 def increment(amount: int) -> None: nonlocal TASK_COUNTER TASK_COUNTER += amount async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks = BackgroundTasks() tasks.add_task(increment, amount=1) tasks.add_task(increment, amount=2) tasks.add_task(increment, amount=3) response = Response("tasks initiated", media_type="text/plain", background=tasks) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "tasks initiated" assert TASK_COUNTER == 1 + 2 + 3 def test_multi_tasks_failure_avoids_next_execution( test_client_factory: TestClientFactory, ) -> None: TASK_COUNTER = 0 def increment() -> None: nonlocal TASK_COUNTER TASK_COUNTER += 1 if TASK_COUNTER == 1: # pragma: no branch raise Exception("task failed") async def app(scope: Scope, receive: Receive, send: Send) -> None: tasks = BackgroundTasks() tasks.add_task(increment) tasks.add_task(increment) response = Response("tasks initiated", media_type="text/plain", background=tasks) await response(scope, receive, send) client = test_client_factory(app) with pytest.raises(Exception): client.get("/") assert TASK_COUNTER == 1 
================================================ FILE: tests/test_concurrency.py ================================================ from collections.abc import Iterator from contextvars import ContextVar import anyio import pytest from starlette.applications import Starlette from starlette.concurrency import iterate_in_threadpool, run_until_first_complete from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route from tests.types import TestClientFactory @pytest.mark.anyio async def test_run_until_first_complete() -> None: task1_finished = anyio.Event() task2_finished = anyio.Event() async def task1() -> None: task1_finished.set() async def task2() -> None: await task1_finished.wait() await anyio.sleep(0) # pragma: no cover task2_finished.set() # pragma: no cover await run_until_first_complete((task1, {}), (task2, {})) assert task1_finished.is_set() assert not task2_finished.is_set() def test_accessing_context_from_threaded_sync_endpoint( test_client_factory: TestClientFactory, ) -> None: ctxvar: ContextVar[bytes] = ContextVar("ctxvar") ctxvar.set(b"data") def endpoint(request: Request) -> Response: return Response(ctxvar.get()) app = Starlette(routes=[Route("/", endpoint)]) client = test_client_factory(app) resp = client.get("/") assert resp.content == b"data" @pytest.mark.anyio async def test_iterate_in_threadpool() -> None: class CustomIterable: def __iter__(self) -> Iterator[int]: yield from range(3) assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2] ================================================ FILE: tests/test_config.py ================================================ import os from pathlib import Path from typing import Any import pytest from typing_extensions import assert_type from starlette.config import Config, Environ, EnvironError from starlette.datastructures import URL, Secret def test_config_types() -> None: """ We use `assert_type` to test the types returned by 
Config via mypy. """ config = Config(environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"}) assert_type(config("STR"), str) assert_type(config("STR_DEFAULT", default=""), str) assert_type(config("STR_CAST", cast=str), str) assert_type(config("STR_NONE", default=None), str | None) assert_type(config("STR_CAST_NONE", cast=str, default=None), str | None) assert_type(config("STR_CAST_STR", cast=str, default=""), str) assert_type(config("BOOL", cast=bool), bool) assert_type(config("BOOL_DEFAULT", cast=bool, default=False), bool) assert_type(config("BOOL_NONE", cast=bool, default=None), bool | None) def cast_to_int(v: Any) -> int: return int(v) # our type annotations allow these `cast` and `default` configurations, but # the code will error at runtime. with pytest.raises(ValueError): config("INT_CAST_DEFAULT_STR", cast=cast_to_int, default="true") with pytest.raises(ValueError): config("INT_DEFAULT_STR", cast=int, default="true") def test_config(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None: path = os.path.join(tmpdir, ".env") with open(path, "w") as file: file.write("# Do not commit to source control\n") file.write("DATABASE_URL=postgres://user:pass@localhost/dbname\n") file.write("REQUEST_HOSTNAME=example.com\n") file.write("SECRET_KEY=12345\n") file.write("BOOL_AS_INT=0\n") file.write("\n") file.write("\n") config = Config(path, environ={"DEBUG": "true"}) def cast_to_int(v: Any) -> int: return int(v) DEBUG = config("DEBUG", cast=bool) DATABASE_URL = config("DATABASE_URL", cast=URL) REQUEST_TIMEOUT = config("REQUEST_TIMEOUT", cast=int, default=10) REQUEST_HOSTNAME = config("REQUEST_HOSTNAME") MAIL_HOSTNAME = config("MAIL_HOSTNAME", default=None) SECRET_KEY = config("SECRET_KEY", cast=Secret) UNSET_SECRET = config("UNSET_SECRET", cast=Secret, default=None) EMPTY_SECRET = config("EMPTY_SECRET", cast=Secret, default="") assert config("BOOL_AS_INT", cast=bool) is False assert config("BOOL_AS_INT", cast=cast_to_int) == 0 assert 
config("DEFAULTED_BOOL", cast=cast_to_int, default=True) == 1 assert DEBUG is True assert DATABASE_URL.path == "/dbname" assert DATABASE_URL.password == "pass" assert DATABASE_URL.username == "user" assert REQUEST_TIMEOUT == 10 assert REQUEST_HOSTNAME == "example.com" assert MAIL_HOSTNAME is None assert repr(SECRET_KEY) == "Secret('**********')" assert str(SECRET_KEY) == "12345" assert bool(SECRET_KEY) assert not bool(EMPTY_SECRET) assert not bool(UNSET_SECRET) with pytest.raises(KeyError): config.get("MISSING") with pytest.raises(ValueError): config.get("DEBUG", cast=int) with pytest.raises(ValueError): config.get("REQUEST_HOSTNAME", cast=bool) config = Config(Path(path)) REQUEST_HOSTNAME = config("REQUEST_HOSTNAME") assert REQUEST_HOSTNAME == "example.com" config = Config() monkeypatch.setenv("STARLETTE_EXAMPLE_TEST", "123") monkeypatch.setenv("BOOL_AS_INT", "1") assert config.get("STARLETTE_EXAMPLE_TEST", cast=int) == 123 assert config.get("BOOL_AS_INT", cast=bool) is True monkeypatch.setenv("BOOL_AS_INT", "2") with pytest.raises(ValueError): config.get("BOOL_AS_INT", cast=bool) def test_missing_env_file_raises(tmpdir: Path) -> None: path = os.path.join(tmpdir, ".env") with pytest.warns(UserWarning, match=f"Config file '{path}' not found."): Config(path) def test_environ() -> None: environ = Environ() # We can mutate the environ at this point. environ["TESTING"] = "True" environ["GONE"] = "123" del environ["GONE"] # We can read the environ. assert environ["TESTING"] == "True" assert "GONE" not in environ # We cannot mutate these keys now that we've read them. with pytest.raises(EnvironError): environ["TESTING"] = "False" with pytest.raises(EnvironError): del environ["GONE"] # Test coverage of abstract methods for MutableMapping. 
environ = Environ() assert list(iter(environ)) == list(iter(os.environ)) assert len(environ) == len(os.environ) def test_config_with_env_prefix(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None: config = Config(environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_") assert config.get("DEBUG") == "value" with pytest.raises(KeyError): config.get("ENVIRONMENT") def test_config_with_encoding(tmpdir: Path) -> None: path = tmpdir / ".env" path.write_text("MESSAGE=Hello 世界\n", encoding="utf-8") config = Config(path, encoding="utf-8") assert config.get("MESSAGE") == "Hello 世界" ================================================ FILE: tests/test_convertors.py ================================================ from collections.abc import Iterator from datetime import datetime from uuid import UUID import pytest from starlette import convertors from starlette.convertors import Convertor, register_url_convertor from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route, Router from tests.types import TestClientFactory @pytest.fixture(scope="module", autouse=True) def refresh_convertor_types() -> Iterator[None]: convert_types = convertors.CONVERTOR_TYPES.copy() yield convertors.CONVERTOR_TYPES = convert_types class DateTimeConvertor(Convertor[datetime]): regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(.[0-9]+)?" 
def convert(self, value: str) -> datetime: return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S") def to_string(self, value: datetime) -> str: return value.strftime("%Y-%m-%dT%H:%M:%S") @pytest.fixture(scope="function") def app() -> Router: register_url_convertor("datetime", DateTimeConvertor()) def datetime_convertor(request: Request) -> JSONResponse: param = request.path_params["param"] assert isinstance(param, datetime) return JSONResponse({"datetime": param.strftime("%Y-%m-%dT%H:%M:%S")}) return Router( routes=[ Route( "/datetime/{param:datetime}", endpoint=datetime_convertor, name="datetime-convertor", ) ] ) def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) -> None: client = test_client_factory(app) response = client.get("/datetime/2020-01-01T00:00:00") assert response.json() == {"datetime": "2020-01-01T00:00:00"} assert ( app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0)) == "/datetime/1996-01-22T23:00:00" ) @pytest.mark.parametrize("param, status_code", [("1.0", 200), ("1-0", 404)]) def test_default_float_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None: def float_convertor(request: Request) -> JSONResponse: param = request.path_params["param"] assert isinstance(param, float) return JSONResponse({"float": param}) app = Router(routes=[Route("/{param:float}", endpoint=float_convertor)]) client = test_client_factory(app) response = client.get(f"/{param}") assert response.status_code == status_code @pytest.mark.parametrize( "param, status_code", [ ("00000000-aaaa-ffff-9999-000000000000", 200), ("00000000aaaaffff9999000000000000", 200), ("00000000-AAAA-FFFF-9999-000000000000", 200), ("00000000AAAAFFFF9999000000000000", 200), ("not-a-uuid", 404), ], ) def test_default_uuid_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None: def uuid_convertor(request: Request) -> JSONResponse: param = request.path_params["param"] assert 
isinstance(param, UUID) return JSONResponse("ok") app = Router(routes=[Route("/{param:uuid}", endpoint=uuid_convertor)]) client = test_client_factory(app) response = client.get(f"/{param}") assert response.status_code == status_code ================================================ FILE: tests/test_datastructures.py ================================================ import io from tempfile import SpooledTemporaryFile from typing import BinaryIO import pytest from starlette.datastructures import ( URL, CommaSeparatedStrings, FormData, Headers, MultiDict, MutableHeaders, QueryParams, UploadFile, ) def test_url() -> None: u = URL("https://example.org:123/path/to/somewhere?abc=123#anchor") assert u.scheme == "https" assert u.hostname == "example.org" assert u.port == 123 assert u.netloc == "example.org:123" assert u.username is None assert u.password is None assert u.path == "/path/to/somewhere" assert u.query == "abc=123" assert u.fragment == "anchor" new = u.replace(scheme="http") assert new == "http://example.org:123/path/to/somewhere?abc=123#anchor" assert new.scheme == "http" new = u.replace(port=None) assert new == "https://example.org/path/to/somewhere?abc=123#anchor" assert new.port is None new = u.replace(hostname="example.com") assert new == "https://example.com:123/path/to/somewhere?abc=123#anchor" assert new.hostname == "example.com" ipv6_url = URL("https://[fe::2]:12345") new = ipv6_url.replace(port=8080) assert new == "https://[fe::2]:8080" new = ipv6_url.replace(username="username", password="password") assert new == "https://username:password@[fe::2]:12345" assert new.netloc == "username:password@[fe::2]:12345" ipv6_url = URL("https://[fe::2]") new = ipv6_url.replace(port=123) assert new == "https://[fe::2]:123" url = URL("http://u:p@host/") assert url.replace(hostname="bar") == URL("http://u:p@bar/") url = URL("http://u:p@host:80") assert url.replace(port=88) == URL("http://u:p@host:88") url = URL("http://host:80") assert url.replace(username="u") == 
URL("http://u@host:80") def test_url_query_params() -> None: u = URL("https://example.org/path/?page=3") assert u.query == "page=3" u = u.include_query_params(page=4) assert str(u) == "https://example.org/path/?page=4" u = u.include_query_params(search="testing") assert str(u) == "https://example.org/path/?page=4&search=testing" u = u.replace_query_params(order="name") assert str(u) == "https://example.org/path/?order=name" u = u.remove_query_params("order") assert str(u) == "https://example.org/path/" u = u.include_query_params(page=4, search="testing") assert str(u) == "https://example.org/path/?page=4&search=testing" u = u.remove_query_params(["page", "search"]) assert str(u) == "https://example.org/path/" def test_hidden_password() -> None: u = URL("https://example.org/path/to/somewhere") assert repr(u) == "URL('https://example.org/path/to/somewhere')" u = URL("https://username@example.org/path/to/somewhere") assert repr(u) == "URL('https://username@example.org/path/to/somewhere')" u = URL("https://username:password@example.org/path/to/somewhere") assert repr(u) == "URL('https://username:********@example.org/path/to/somewhere')" def test_csv() -> None: csv = CommaSeparatedStrings('"localhost", "127.0.0.1", 0.0.0.0') assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'" assert csv[0] == "localhost" assert len(csv) == 3 csv = CommaSeparatedStrings("'localhost', '127.0.0.1', 0.0.0.0") assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'" csv = CommaSeparatedStrings("localhost, 127.0.0.1, 0.0.0.0") assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" assert str(csv) == "'localhost', '127.0.0.1', 
'0.0.0.0'" csv = CommaSeparatedStrings(["localhost", "127.0.0.1", "0.0.0.0"]) assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'" def test_url_from_scope() -> None: u = URL(scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}) assert u == "/path/to/somewhere?abc=123" assert repr(u) == "URL('/path/to/somewhere?abc=123')" u = URL( scope={ "scheme": "https", "server": ("example.org", 123), "path": "/path/to/somewhere", "query_string": b"abc=123", "headers": [], } ) assert u == "https://example.org:123/path/to/somewhere?abc=123" assert repr(u) == "URL('https://example.org:123/path/to/somewhere?abc=123')" u = URL( scope={ "scheme": "https", "server": ("example.org", 443), "path": "/path/to/somewhere", "query_string": b"abc=123", "headers": [], } ) assert u == "https://example.org/path/to/somewhere?abc=123" assert repr(u) == "URL('https://example.org/path/to/somewhere?abc=123')" u = URL( scope={ "scheme": "http", "path": "/some/path", "query_string": b"query=string", "headers": [ (b"content-type", b"text/html"), (b"host", b"example.com:8000"), (b"accept", b"text/html"), ], } ) assert u == "http://example.com:8000/some/path?query=string" assert repr(u) == "URL('http://example.com:8000/some/path?query=string')" def test_headers() -> None: h = Headers(raw=[(b"a", b"123"), (b"a", b"456"), (b"b", b"789")]) assert "a" in h assert "A" in h assert "b" in h assert "B" in h assert "c" not in h assert h["a"] == "123" assert h.get("a") == "123" assert h.get("nope", default=None) is None assert h.getlist("a") == ["123", "456"] assert h.keys() == ["a", "a", "b"] assert h.values() == ["123", "456", "789"] assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")] assert list(h) == ["a", "a", "b"] assert dict(h) == {"a": "123", "b": "789"} assert repr(h) == "Headers(raw=[(b'a', b'123'), (b'a', b'456'), (b'b', 
b'789')])" assert h == Headers(raw=[(b"a", b"123"), (b"b", b"789"), (b"a", b"456")]) assert h != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")] h = Headers({"a": "123", "b": "789"}) assert h["A"] == "123" assert h["B"] == "789" assert h.raw == [(b"a", b"123"), (b"b", b"789")] assert repr(h) == "Headers({'a': '123', 'b': '789'})" def test_mutable_headers() -> None: h = MutableHeaders() assert dict(h) == {} h["a"] = "1" assert dict(h) == {"a": "1"} h["a"] = "2" assert dict(h) == {"a": "2"} h.setdefault("a", "3") assert dict(h) == {"a": "2"} h.setdefault("b", "4") assert dict(h) == {"a": "2", "b": "4"} del h["a"] assert dict(h) == {"b": "4"} assert h.raw == [(b"b", b"4")] def test_mutable_headers_merge() -> None: h = MutableHeaders() h = h | MutableHeaders({"a": "1"}) assert isinstance(h, MutableHeaders) assert dict(h) == {"a": "1"} assert h.items() == [("a", "1")] assert h.raw == [(b"a", b"1")] def test_mutable_headers_merge_dict() -> None: h = MutableHeaders() h = h | {"a": "1"} assert isinstance(h, MutableHeaders) assert dict(h) == {"a": "1"} assert h.items() == [("a", "1")] assert h.raw == [(b"a", b"1")] def test_mutable_headers_update() -> None: h = MutableHeaders() h |= MutableHeaders({"a": "1"}) assert isinstance(h, MutableHeaders) assert dict(h) == {"a": "1"} assert h.items() == [("a", "1")] assert h.raw == [(b"a", b"1")] def test_mutable_headers_update_dict() -> None: h = MutableHeaders() h |= {"a": "1"} assert isinstance(h, MutableHeaders) assert dict(h) == {"a": "1"} assert h.items() == [("a", "1")] assert h.raw == [(b"a", b"1")] def test_mutable_headers_merge_not_mapping() -> None: h = MutableHeaders() with pytest.raises(TypeError): h |= {"not_mapping"} # type: ignore[arg-type] with pytest.raises(TypeError): h | {"not_mapping"} # type: ignore[operator] def test_headers_mutablecopy() -> None: h = Headers(raw=[(b"a", b"123"), (b"a", b"456"), (b"b", b"789")]) c = h.mutablecopy() assert c.items() == [("a", "123"), ("a", "456"), ("b", "789")] c["a"] = "abc" 
assert c.items() == [("a", "abc"), ("b", "789")] def test_mutable_headers_from_scope() -> None: # "headers" in scope must not necessarily be a list h = MutableHeaders(scope={"headers": ((b"a", b"1"),)}) assert dict(h) == {"a": "1"} h.update({"b": "2"}) assert dict(h) == {"a": "1", "b": "2"} assert list(h.items()) == [("a", "1"), ("b", "2")] assert list(h.raw) == [(b"a", b"1"), (b"b", b"2")] def test_url_blank_params() -> None: q = QueryParams("a=123&abc&def&b=456") assert "a" in q assert "abc" in q assert "def" in q assert "b" in q val = q.get("abc") assert val is not None assert len(val) == 0 assert len(q["a"]) == 3 assert list(q.keys()) == ["a", "abc", "def", "b"] def test_queryparams() -> None: q = QueryParams("a=123&a=456&b=789") assert "a" in q assert "A" not in q assert "c" not in q assert q["a"] == "456" assert q.get("a") == "456" assert q.get("nope", default=None) is None assert q.getlist("a") == ["123", "456"] assert list(q.keys()) == ["a", "b"] assert list(q.values()) == ["456", "789"] assert list(q.items()) == [("a", "456"), ("b", "789")] assert len(q) == 2 assert list(q) == ["a", "b"] assert dict(q) == {"a": "456", "b": "789"} assert str(q) == "a=123&a=456&b=789" assert repr(q) == "QueryParams('a=123&a=456&b=789')" assert QueryParams({"a": "123", "b": "456"}) == QueryParams([("a", "123"), ("b", "456")]) assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456") assert QueryParams({"a": "123", "b": "456"}) == QueryParams({"b": "456", "a": "123"}) assert QueryParams() == QueryParams({}) assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456") assert QueryParams({"a": "123", "b": "456"}) != "invalid" q = QueryParams([("a", "123"), ("a", "456")]) assert QueryParams(q) == q @pytest.mark.anyio async def test_upload_file_file_input() -> None: """Test passing file/stream into the UploadFile constructor""" stream = io.BytesIO(b"data") file = UploadFile(filename="file", file=stream, size=4) assert await file.read() == b"data" 
assert file.size == 4 await file.write(b" and more data!") assert await file.read() == b"" assert file.size == 19 await file.seek(0) assert await file.read() == b"data and more data!" @pytest.mark.anyio async def test_upload_file_without_size() -> None: """Test passing file/stream into the UploadFile constructor without size""" stream = io.BytesIO(b"data") file = UploadFile(filename="file", file=stream) assert await file.read() == b"data" assert file.size is None await file.write(b" and more data!") assert await file.read() == b"" assert file.size is None await file.seek(0) assert await file.read() == b"data and more data!" @pytest.mark.anyio @pytest.mark.parametrize("max_size", [1, 1024], ids=["rolled", "unrolled"]) async def test_uploadfile_rolling(max_size: int) -> None: """Test that we can r/w to a SpooledTemporaryFile managed by UploadFile before and after it rolls to disk """ stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment] max_size=max_size ) file = UploadFile(filename="file", file=stream, size=0) assert await file.read() == b"" assert file.size == 0 await file.write(b"data") assert await file.read() == b"" assert file.size == 4 await file.seek(0) assert await file.read() == b"data" await file.write(b" more") assert await file.read() == b"" assert file.size == 9 await file.seek(0) assert await file.read() == b"data more" assert file.size == 9 await file.close() def test_formdata() -> None: stream = io.BytesIO(b"data") upload = UploadFile(filename="file", file=stream, size=4) form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) assert "a" in form assert "A" not in form assert "c" not in form assert form["a"] == "456" assert form.get("a") == "456" assert form.get("nope", default=None) is None assert form.getlist("a") == ["123", "456"] assert list(form.keys()) == ["a", "b"] assert list(form.values()) == ["456", upload] assert list(form.items()) == [("a", "456"), ("b", upload)] assert len(form) == 2 assert list(form) == ["a", "b"] 
assert dict(form) == {"a": "456", "b": upload} assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" assert FormData(form) == form assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"} @pytest.mark.anyio async def test_upload_file_repr() -> None: stream = io.BytesIO(b"data") file = UploadFile(filename="file", file=stream, size=4) assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))" @pytest.mark.anyio async def test_upload_file_repr_headers() -> None: stream = io.BytesIO(b"data") file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"})) assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))" def test_multidict() -> None: q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) assert "a" in q assert "A" not in q assert "c" not in q assert q["a"] == "456" assert q.get("a") == "456" assert q.get("nope", default=None) is None assert q.getlist("a") == ["123", "456"] assert list(q.keys()) == ["a", "b"] assert list(q.values()) == ["456", "789"] assert list(q.items()) == [("a", "456"), ("b", "789")] assert len(q) == 2 assert list(q) == ["a", "b"] assert dict(q) == {"a": "456", "b": "789"} assert str(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])" assert repr(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])" assert MultiDict({"a": "123", "b": "456"}) == MultiDict([("a", "123"), ("b", "456")]) assert MultiDict({"a": "123", "b": "456"}) == MultiDict({"b": "456", "a": "123"}) assert MultiDict() == MultiDict({}) assert MultiDict({"a": "123", "b": "456"}) != "invalid" q = MultiDict([("a", "123"), ("a", "456")]) assert MultiDict(q) == q q = MultiDict([("a", "123"), ("a", "456")]) q["a"] = "789" assert q["a"] == "789" assert q.get("a") == "789" assert q.getlist("a") == ["789"] q = MultiDict([("a", "123"), ("a", "456")]) del q["a"] 
assert q.get("a") is None assert repr(q) == "MultiDict([])" q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) assert q.pop("a") == "456" assert q.get("a", None) is None assert repr(q) == "MultiDict([('b', '789')])" q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) item = q.popitem() assert q.get(item[0]) is None q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) assert q.poplist("a") == ["123", "456"] assert q.get("a") is None assert repr(q) == "MultiDict([('b', '789')])" q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")]) q.clear() assert q.get("a") is None assert repr(q) == "MultiDict([])" q = MultiDict([("a", "123")]) q.setlist("a", ["456", "789"]) assert q.getlist("a") == ["456", "789"] q.setlist("b", []) assert "b" not in q q = MultiDict([("a", "123")]) assert q.setdefault("a", "456") == "123" assert q.getlist("a") == ["123"] assert q.setdefault("b", "456") == "456" assert q.getlist("b") == ["456"] assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" q = MultiDict([("a", "123")]) q.append("a", "456") assert q.getlist("a") == ["123", "456"] assert repr(q) == "MultiDict([('a', '123'), ('a', '456')])" q = MultiDict([("a", "123"), ("b", "456")]) q.update({"a": "789"}) assert q.getlist("a") == ["789"] assert q == MultiDict([("a", "789"), ("b", "456")]) q = MultiDict([("a", "123"), ("b", "456")]) q.update(q) assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" q = MultiDict([("a", "123"), ("a", "456")]) q.update([("a", "123")]) assert q.getlist("a") == ["123"] q.update([("a", "456")], a="789", b="123") assert q == MultiDict([("a", "456"), ("a", "789"), ("b", "123")]) ================================================ FILE: tests/test_endpoints.py ================================================ from collections.abc import Iterator import pytest from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import 
Route, Router from starlette.testclient import TestClient from starlette.websockets import WebSocket from tests.types import TestClientFactory class Homepage(HTTPEndpoint): async def get(self, request: Request) -> PlainTextResponse: username = request.path_params.get("username") if username is None: return PlainTextResponse("Hello, world!") return PlainTextResponse(f"Hello, {username}!") app = Router(routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)]) @pytest.fixture def client(test_client_factory: TestClientFactory) -> Iterator[TestClient]: with test_client_factory(app) as client: yield client def test_http_endpoint_route(client: TestClient) -> None: response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world!" def test_http_endpoint_route_path_params(client: TestClient) -> None: response = client.get("/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" def test_http_endpoint_route_method(client: TestClient) -> None: response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" assert response.headers["allow"] == "GET" def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None: class WebSocketApp(WebSocketEndpoint): async def on_connect(self, websocket: WebSocket) -> None: assert websocket["subprotocols"] == ["soap", "wamp"] await websocket.accept(subprotocol="wamp") client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" def test_websocket_endpoint_on_receive_bytes( test_client_factory: TestClientFactory, ) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "bytes" async def on_receive(self, websocket: WebSocket, data: bytes) -> None: await websocket.send_bytes(b"Message bytes was: " + data) client = test_client_factory(WebSocketApp) with 
client.websocket_connect("/ws") as websocket: websocket.send_bytes(b"Hello, world!") _bytes = websocket.receive_bytes() assert _bytes == b"Message bytes was: Hello, world!" with pytest.raises(RuntimeError): with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello world") def test_websocket_endpoint_on_receive_json( test_client_factory: TestClientFactory, ) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_json({"message": data}) client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} with pytest.raises(RuntimeError): with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello world") def test_websocket_endpoint_on_receive_json_binary( test_client_factory: TestClientFactory, ) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_json({"message": data}, mode="binary") client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"message": {"hello": "world"}} def test_websocket_endpoint_on_receive_text( test_client_factory: TestClientFactory, ) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "text" async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_text(f"Message text was: {data}") client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() assert _text == "Message text was: Hello, world!" 
with pytest.raises(RuntimeError): with client.websocket_connect("/ws") as websocket: websocket.send_bytes(b"Hello world") def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None: class WebSocketApp(WebSocketEndpoint): encoding = None async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_text(f"Message text was: {data}") client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() assert _text == "Message text was: Hello, world!" def test_websocket_endpoint_on_disconnect( test_client_factory: TestClientFactory, ) -> None: class WebSocketApp(WebSocketEndpoint): async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: assert close_code == 1001 await websocket.close(code=close_code) client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.close(code=1001) ================================================ FILE: tests/test_exceptions.py ================================================ from collections.abc import Generator from typing import Any import pytest from pytest import MonkeyPatch from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send from tests.types import TestClientFactory def raise_runtime_error(request: Request) -> None: raise RuntimeError("Yikes") def not_acceptable(request: Request) -> None: raise HTTPException(status_code=406) def no_content(request: Request) -> None: raise HTTPException(status_code=204) def not_modified(request: Request) -> None: raise HTTPException(status_code=304) def 
with_headers(request: Request) -> None: raise HTTPException(status_code=200, headers={"x-potato": "always"}) class BadBodyException(HTTPException): pass async def read_body_and_raise_exc(request: Request) -> None: await request.body() raise BadBodyException(422) async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse: body = await request.body() return JSONResponse(status_code=422, content={"body": body.decode()}) class HandledExcAfterResponse: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response = PlainTextResponse("OK", status_code=200) await response(scope, receive, send) raise HTTPException(status_code=406) router = Router( routes=[ Route("/runtime_error", endpoint=raise_runtime_error), Route("/not_acceptable", endpoint=not_acceptable), Route("/no_content", endpoint=no_content), Route("/not_modified", endpoint=not_modified), Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]), ] ) app = ExceptionMiddleware( router, handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item] ) @pytest.fixture def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client def test_not_acceptable(client: TestClient) -> None: response = client.get("/not_acceptable") assert response.status_code == 406 assert response.text == "Not Acceptable" def test_no_content(client: TestClient) -> None: response = client.get("/no_content") assert response.status_code == 204 assert "content-length" not in response.headers def test_not_modified(client: TestClient) -> None: response = client.get("/not_modified") assert response.status_code == 304 assert response.text == "" def test_with_headers(client: 
TestClient) -> None: response = client.get("/with_headers") assert response.status_code == 200 assert response.headers["x-potato"] == "always" def test_websockets_should_raise(client: TestClient) -> None: with pytest.raises(RuntimeError): with client.websocket_connect("/runtime_error"): pass # pragma: no cover def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None: # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. with pytest.raises(RuntimeError, match="Caught handled exception, but response already started."): client.get("/handled_exc_after_response") # If `raise_server_exceptions=False` then the test client will still allow # us to see the response as it will have been seen by the client. allow_200_client = test_client_factory(app, raise_server_exceptions=False) response = allow_200_client.get("/handled_exc_after_response") assert response.status_code == 200 assert response.text == "OK" def test_force_500_response(test_client_factory: TestClientFactory) -> None: # use a sentinel variable to make sure we actually # make it into the endpoint and don't get a 500 # from an incorrect ASGI app signature or something called = False async def app(scope: Scope, receive: Receive, send: Send) -> None: nonlocal called called = True raise RuntimeError() force_500_client = test_client_factory(app, raise_server_exceptions=False) response = force_500_client.get("/") assert called assert response.status_code == 500 assert response.text == "" def test_http_str() -> None: assert str(HTTPException(status_code=404)) == "404: Not Found" assert str(HTTPException(404, "Not Found: foo")) == "404: Not Found: foo" assert str(HTTPException(404, headers={"key": "value"})) == "404: Not Found" def test_http_repr() -> None: assert repr(HTTPException(404)) == ("HTTPException(status_code=404, detail='Not Found')") assert repr(HTTPException(404, detail="Not Found: 
foo")) == ( "HTTPException(status_code=404, detail='Not Found: foo')" ) class CustomHTTPException(HTTPException): pass assert repr(CustomHTTPException(500, detail="Something custom")) == ( "CustomHTTPException(status_code=500, detail='Something custom')" ) def test_websocket_str() -> None: assert str(WebSocketException(1008)) == "1008: " assert str(WebSocketException(1008, "Policy Violation")) == "1008: Policy Violation" def test_websocket_repr() -> None: assert repr(WebSocketException(1008, reason="Policy Violation")) == ( "WebSocketException(code=1008, reason='Policy Violation')" ) class CustomWebSocketException(WebSocketException): pass assert ( repr(CustomWebSocketException(1013, reason="Something custom")) == "CustomWebSocketException(code=1013, reason='Something custom')" ) def test_request_in_app_and_handler_is_the_same_object(client: TestClient) -> None: response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!") assert response.status_code == 422 assert response.json() == {"body": "Hello!"} def test_http_exception_does_not_use_threadpool(client: TestClient, monkeypatch: MonkeyPatch) -> None: """ Verify that handling HTTPException does not invoke run_in_threadpool, confirming the handler correctly runs in the main async context. """ from starlette import _exception_handler # Replace run_in_threadpool with a function that raises an error def mock_run_in_threadpool(*args: Any, **kwargs: Any) -> None: pytest.fail("run_in_threadpool should not be called for HTTP exceptions") # pragma: no cover # Apply the monkeypatch only during this test monkeypatch.setattr(_exception_handler, "run_in_threadpool", mock_run_in_threadpool) # This should succeed because http_exception is async and won't use run_in_threadpool response = client.get("/not_acceptable") assert response.status_code == 406 def test_handlers_annotations() -> None: """Check that async exception handlers are accepted by type checkers. 
We annotate the handlers' exceptions with plain `Exception` to avoid variance issues when using other exception types. """ async def async_catch_all_handler(request: Request, exc: Exception) -> JSONResponse: raise NotImplementedError def sync_catch_all_handler(request: Request, exc: Exception) -> JSONResponse: raise NotImplementedError ExceptionMiddleware(router, handlers={Exception: sync_catch_all_handler}) ExceptionMiddleware(router, handlers={Exception: async_catch_all_handler}) ================================================ FILE: tests/test_formparsers.py ================================================ from __future__ import annotations import os import threading from collections.abc import Generator from contextlib import AbstractContextManager, nullcontext as does_not_raise from io import BytesIO from pathlib import Path from tempfile import SpooledTemporaryFile from typing import Any, ClassVar from unittest import mock import pytest from starlette.applications import Starlette from starlette.datastructures import UploadFile from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount from starlette.types import ASGIApp, Receive, Scope, Send from tests.types import TestClientFactory class ForceMultipartDict(dict[Any, Any]): def __bool__(self) -> bool: return True # FORCE_MULTIPART is an empty dict that boolean-evaluates as `True`. 
# Passed as `files=` so the test client encodes the body as multipart/form-data even
# though no real file is attached (ForceMultipartDict is defined earlier in this module).
FORCE_MULTIPART = ForceMultipartDict()


async def app(scope: Scope, receive: Receive, send: Send) -> None:
    """Parse the request form and echo it back as JSON.

    Plain fields are returned unchanged; every `UploadFile` is described by its
    filename, size, decoded content and content type.
    """
    request = Request(scope, receive)
    data = await request.form()
    output: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, UploadFile):
            content = await value.read()
            output[key] = {
                "filename": value.filename,
                "size": value.size,
                "content": content.decode(),
                "content_type": value.content_type,
            }
        else:
            output[key] = value
    await request.close()
    response = JSONResponse(output)
    await response(scope, receive, send)


async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None:
    """Like `app`, but keeps every value of a repeated key (via `multi_items()`) in a list."""
    request = Request(scope, receive)
    data = await request.form()
    output: dict[str, list[Any]] = {}
    for key, value in data.multi_items():
        values = output.setdefault(key, [])
        if isinstance(value, UploadFile):
            content = await value.read()
            values.append(
                {
                    "filename": value.filename,
                    "size": value.size,
                    "content": content.decode(),
                    "content_type": value.content_type,
                }
            )
        else:
            values.append(value)
    await request.close()
    response = JSONResponse(output)
    await response(scope, receive, send)


async def app_with_headers(scope: Scope, receive: Receive, send: Send) -> None:
    """Like `app`, but also reports each uploaded file's part headers as `[name, value]` pairs."""
    request = Request(scope, receive)
    data = await request.form()
    output: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, UploadFile):
            content = await value.read()
            output[key] = {
                "filename": value.filename,
                "size": value.size,
                "content": content.decode(),
                "content_type": value.content_type,
                "headers": list(value.headers.items()),
            }
        else:
            output[key] = value
    await request.close()
    response = JSONResponse(output)
    await response(scope, receive, send)


async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
    """Read the raw body first, then parse the form from the already-buffered body."""
    request = Request(scope, receive)
    # Read bytes, to force request.stream() to return the already-read body
    await request.body()
    data = await request.form()
    output: dict[str, Any] = {}
    for key, value in data.items():
        output[key] = value
    await request.close()
    response = JSONResponse(output)
    await response(scope, receive, send)


async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
    """Helper app to monitor what thread the app was called on.

    This can later be used to validate thread/event loop operations.
    """
    request = Request(scope, receive)

    # Make sure we parse the form
    await request.form()
    await request.close()

    # Send back the current thread id
    response = JSONResponse({"thread_ident": threading.current_thread().ident})
    await response(scope, receive, send)


def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
    """Build an `app`-style echo endpoint whose form parsing uses the given part limits."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive)
        data = await request.form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
        output: dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, UploadFile):
                content = await value.read()
                output[key] = {
                    "filename": value.filename,
                    "size": value.size,
                    "content": content.decode(),
                    "content_type": value.content_type,
                }
            else:
                output[key] = value
        await request.close()
        response = JSONResponse(output)
        await response(scope, receive, send)

    return app


def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
    assert response.json() == {"some": "data"}


def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path = os.path.join(tmpdir, "test.txt")
    with open(path, "wb") as file:
        file.write(b"<file content>")  # 14 bytes

    client = test_client_factory(app)
    with open(path, "rb") as f:
        response = client.post("/", files={"test": f})
        assert response.json() == {
            "test": {
                "filename": "test.txt",
                "size": 14,
                "content": "<file content>",
                "content_type": "text/plain",
            }
        }


def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path = os.path.join(tmpdir, "test.txt")
    with open(path, "wb") as file:
        file.write(b"<file content>")  # 14 bytes

    client = test_client_factory(app)
    with open(path, "rb") as f:
        response = client.post("/", files={"test": ("test.txt", f, "text/plain")})
        assert response.json() == {
            "test": {
                "filename": "test.txt",
                "size": 14,
                "content": "<file content>",
                "content_type": "text/plain",
            }
        }


def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path1 = os.path.join(tmpdir, "test1.txt")
    with open(path1, "wb") as file:
        file.write(b"<file1 content>")  # 15 bytes

    path2 = os.path.join(tmpdir, "test2.txt")
    with open(path2, "wb") as file:
        file.write(b"<file2 content>")  # 15 bytes

    client = test_client_factory(app)
    with open(path1, "rb") as f1, open(path2, "rb") as f2:
        response = client.post("/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")})
        assert response.json() == {
            "test1": {
                "filename": "test1.txt",
                "size": 15,
                "content": "<file1 content>",
                "content_type": "text/plain",
            },
            "test2": {
                "filename": "test2.txt",
                "size": 15,
                "content": "<file2 content>",
                "content_type": "text/plain",
            },
        }


def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path1 = os.path.join(tmpdir, "test1.txt")
    with open(path1, "wb") as file:
        file.write(b"<file1 content>")

    path2 = os.path.join(tmpdir, "test2.txt")
    with open(path2, "wb") as file:
        file.write(b"<file2 content>")

    client = test_client_factory(app_with_headers)
    with open(path1, "rb") as f1, open(path2, "rb") as f2:
        response = client.post(
            "/",
            files=[
                # No filename: the part is parsed as a plain field, not an UploadFile.
                ("test1", (None, f1)),
                ("test2", ("test2.txt", f2, "text/plain", {"x-custom": "f2"})),
            ],
        )
        assert response.json() == {
            "test1": "<file1 content>",
            "test2": {
                "filename": "test2.txt",
                "size": 15,
                "content": "<file2 content>",
                "content_type": "text/plain",
                "headers": [
                    [
                        "content-disposition",
                        'form-data; name="test2"; filename="test2.txt"',
                    ],
                    ["x-custom", "f2"],
                    ["content-type", "text/plain"],
                ],
            },
        }


def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path1 = os.path.join(tmpdir, "test1.txt")
    with open(path1, "wb") as file:
        file.write(b"<file1 content>")

    path2 = os.path.join(tmpdir, "test2.txt")
    with open(path2, "wb") as file:
        file.write(b"<file2 content>")

    client = test_client_factory(multi_items_app)
    with open(path1, "rb") as f1, open(path2, "rb") as f2:
        response = client.post(
            "/",
            data={"test1": "abc"},
            files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))],
        )
        assert response.json() == {
            "test1": [
                "abc",
                {
                    "filename": "test1.txt",
                    "size": 15,
                    "content": "<file1 content>",
                    "content_type": "text/plain",
                },
                {
                    "filename": "test2.txt",
                    "size": 15,
                    "content": "<file2 content>",
                    "content_type": "text/plain",
                },
            ]
        }


def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post(
        "/",
        content=(
            # data
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
            b'Content-Disposition: form-data; name="field0"\r\n\r\n'
            b"value0\r\n"
            # file
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
            b'Content-Disposition: form-data; name="file"; filename="file.txt"\r\n'
            b"Content-Type: text/plain\r\n\r\n"
            b"<file content>\r\n"
            # data
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
            b'Content-Disposition: form-data; name="field1"\r\n\r\n'
            b"value1\r\n"
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
        ),
        headers={"Content-Type": "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"},
    )
    assert response.json() == {
        "file": {
            "filename": "file.txt",
            "size": 14,
            "content": "<file content>",
            "content_type": "text/plain",
        },
        "field0": "value0",
        "field1": "value1",
    }


class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
    """Helper class to track which threads performed the rollover operation.

    This is not threadsafe/multi-test safe.
    """

    # Thread idents recorded by `rollover`; cleared by the `mock_spooled_temporary_file` fixture.
    rollover_threads: ClassVar[set[int | None]] = set()

    def rollover(self) -> None:
        ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
        super().rollover()


@pytest.fixture
def mock_spooled_temporary_file() -> Generator[None]:
    """Swap the parser's SpooledTemporaryFile for the thread-tracking subclass for one test."""
    try:
        with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
            yield
    finally:
        ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()


def test_multipart_request_large_file_rollover_in_background_thread(
    mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
) -> None:
    """Test that Spooled file rollovers happen in background threads."""
    data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))

    client = test_client_factory(app_monitor_thread)
    response = client.post("/", files=[("test_large", data)])
    assert response.status_code == 200

    # Parse the event thread id from the API response and ensure we have one
    app_thread_ident = response.json().get("thread_ident")
    assert app_thread_ident is not None

    # Ensure the app thread was not the same as the rollover one and that a rollover thread exists
    assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
    assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1


def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post(
        "/",
        content=(
            # file
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
            b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
            b"Content-Type: text/plain\r\n\r\n"
            b"<file content>\r\n"
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
        ),
        headers={"Content-Type": "multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"},
    )
    assert response.json() == {
        "file": {
            "filename": "文書.txt",
            "size": 14,
            "content": "<file content>",
            "content_type": "text/plain",
        }
    }


def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post(
        "/",
        content=(
            # file
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
            b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'
            b"Content-Type: image/jpeg\r\n\r\n"
            b"<file content>\r\n"
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
        ),
        headers={"Content-Type": "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"},
    )
    assert response.json() == {
        "file": {
            "filename": "画像.jpg",
            "size": 14,
            "content": "<file content>",
            "content_type": "image/jpeg",
        }
    }


def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post(
        "/",
        content=(
            b"--20b303e711c4ab8c443184ac833ab00f\r\n"
            b"Content-Disposition: form-data; "
            b'name="value"\r\n\r\n'
            b"Transf\xc3\xa9rer\r\n"
            b"--20b303e711c4ab8c443184ac833ab00f--\r\n"
        ),
        headers={"Content-Type": "multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f"},
    )
    assert response.json() == {"value": "Transférer"}


def test_urlencoded_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post("/", data={"some": "data"})
    assert response.json() == {"some": "data"}


def test_no_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post("/")
    assert response.json() == {}


def test_urlencoded_percent_encoding(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post("/", data={"some": "da ta"})
    assert response.json() == {"some": "da ta"}


def test_urlencoded_percent_encoding_keys(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post("/", data={"so me": "data"})
    assert response.json() == {"so me": "data"}


def test_urlencoded_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app_read_body)
    response = client.post("/", data={"some": "data", "second": "key pair"})
    assert response.json() == {"some": "data", "second": "key pair"}


def test_multipart_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app_read_body)
    response = client.post("/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART)
    assert response.json() == {"some": "data", "second": "key pair"}


def test_user_safe_decode_helper() -> None:
    result = _user_safe_decode(b"\xc4\x99\xc5\xbc\xc4\x87", "utf-8")
    assert result == "ężć"


def test_user_safe_decode_ignores_wrong_charset() -> None:
    # "latin-8" is not a real codec; the helper is expected to fall back instead of raising.
    result = _user_safe_decode(b"abc", "latin-8")
    assert result == "abc"


@pytest.mark.parametrize(
    "app,expectation",
    [
        (app, pytest.raises(MultiPartException)),
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_missing_boundary_parameter(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """A multipart request without a boundary raises in a bare app and yields a 400 under Starlette."""
    client = test_client_factory(app)
    with expectation:
        res = client.post(
            "/",
            content=(
                # file
                b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
                b"Content-Type: text/plain\r\n\r\n"
                b"<file content>\r\n"
            ),
            headers={"Content-Type": "multipart/form-data; charset=utf-8"},
        )
        assert res.status_code == 400
        assert res.text == "Missing boundary in multipart."
@pytest.mark.parametrize(
    "app,expectation",
    [
        (app, pytest.raises(MultiPartException)),
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_missing_name_parameter_on_content_disposition(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """A part whose Content-Disposition has no `name` raises in a bare app and yields a 400 under Starlette."""
    client = test_client_factory(app)
    with expectation:
        res = client.post(
            "/",
            content=(
                # data
                b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
                b'Content-Disposition: form-data; ="field0"\r\n\r\n'
                b"value0\r\n"
            ),
            headers={"Content-Type": "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"},
        )
        assert res.status_code == 400
        assert res.text == 'The Content-Disposition header field "name" must be provided.'


@pytest.mark.parametrize(
    "app,expectation",
    [
        (app, pytest.raises(MultiPartException)),
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_too_many_fields_raise(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """One field more than the default limit of 1000 is rejected."""
    client = test_client_factory(app)
    data = "".join(
        f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n' for i in range(1001)
    ).encode("utf-8")
    with expectation:
        res = client.post(
            "/",
            content=data,
            headers={"Content-Type": "multipart/form-data; boundary=B"},
        )
        assert res.status_code == 400
        assert res.text == "Too many fields. Maximum number of fields is 1000."
@pytest.mark.parametrize(
    "app,expectation",
    [
        (app, pytest.raises(MultiPartException)),
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_too_many_files_raise(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """One file more than the default limit of 1000, each under its own field name, is rejected."""
    client = test_client_factory(app)
    data = "".join(
        f'--B\r\nContent-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n\r\n' for i in range(1001)
    ).encode("utf-8")
    with expectation:
        res = client.post(
            "/",
            content=data,
            headers={"Content-Type": "multipart/form-data; boundary=B"},
        )
        assert res.status_code == 400
        assert res.text == "Too many files. Maximum number of files is 1000."


@pytest.mark.parametrize(
    "app,expectation",
    [
        (app, pytest.raises(MultiPartException)),
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_too_many_files_single_field_raise(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """The file limit also applies when all files share one field name."""
    client = test_client_factory(app)
    # This uses the same field name "N" for all files, equivalent to a
    # multifile upload form field
    data = "".join(
        f'--B\r\nContent-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n\r\n' for i in range(1001)
    ).encode("utf-8")
    with expectation:
        res = client.post(
            "/",
            content=data,
            headers={"Content-Type": "multipart/form-data; boundary=B"},
        )
        assert res.status_code == 400
        assert res.text == "Too many files. Maximum number of files is 1000."
@pytest.mark.parametrize(
    "app,expectation",
    [
        (app, pytest.raises(MultiPartException)),
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_too_many_files_and_fields_raise(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """With interleaved file and field parts, exceeding the file limit is what gets reported."""
    client = test_client_factory(app)
    # Each iteration emits a file part followed by a plain field part.
    data = "".join(
        f'--B\r\nContent-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n\r\n'
        f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n'
        for i in range(1001)
    ).encode("utf-8")
    with expectation:
        res = client.post(
            "/",
            content=data,
            headers={"Content-Type": "multipart/form-data; boundary=B"},
        )
        assert res.status_code == 400
        assert res.text == "Too many files. Maximum number of files is 1000."


@pytest.mark.parametrize(
    "app,expectation",
    [
        (make_app_max_parts(max_fields=1), pytest.raises(MultiPartException)),
        (
            Starlette(routes=[Mount("/", app=make_app_max_parts(max_fields=1))]),
            does_not_raise(),
        ),
    ],
)
def test_max_fields_is_customizable_low_raises(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """With `max_fields=1`, a second field part is rejected."""
    client = test_client_factory(app)
    data = "".join(f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n' for i in range(2)).encode(
        "utf-8"
    )
    with expectation:
        res = client.post(
            "/",
            content=data,
            headers={"Content-Type": "multipart/form-data; boundary=B"},
        )
        assert res.status_code == 400
        assert res.text == "Too many fields. Maximum number of fields is 1."
@pytest.mark.parametrize(
    "app,expectation",
    [
        (make_app_max_parts(max_files=1), pytest.raises(MultiPartException)),
        (
            Starlette(routes=[Mount("/", app=make_app_max_parts(max_files=1))]),
            does_not_raise(),
        ),
    ],
)
def test_max_files_is_customizable_low_raises(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """With `max_files=1`, a second file part is rejected."""
    client = test_client_factory(app)
    data = "".join(
        f'--B\r\nContent-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n\r\n' for i in range(2)
    ).encode("utf-8")
    with expectation:
        res = client.post(
            "/",
            content=data,
            headers={"Content-Type": "multipart/form-data; boundary=B"},
        )
        assert res.status_code == 400
        assert res.text == "Too many files. Maximum number of files is 1."


def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) -> None:
    """Raised limits (2000 fields / 2000 files) accept a body right at those limits."""
    client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
    data = "".join(
        f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n'
        f'--B\r\nContent-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n\r\n'
        for i in range(2000)
    ).encode("utf-8")
    # Terminate the body with the closing boundary so parsing completes successfully.
    data += b"--B--\r\n"
    res = client.post(
        "/",
        content=data,
        headers={"Content-Type": "multipart/form-data; boundary=B"},
    )
    assert res.status_code == 200
    res_data = res.json()
    assert res_data["N1999"] == ""
    assert res_data["F1999"] == {
        "filename": "F1999",
        "size": 0,
        "content": "",
        "content_type": None,
    }


@pytest.mark.parametrize(
    "app,expectation",
    [
        (app, pytest.raises(MultiPartException)),
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_max_part_size_exceeds_limit(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """A field one byte larger than the default 1MB part limit is rejected."""
    client = test_client_factory(app)
    boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"

    multipart_data = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="small"\r\n\r\n'
        "small content\r\n"
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="large"\r\n\r\n'
        + ("x" * 1024 * 1024 + "x")  # 1MB + 1 byte of data
        + "\r\n"
        f"--{boundary}--\r\n"
    ).encode("utf-8")

    headers = {
        "Content-Type": f"multipart/form-data; boundary={boundary}",
        "Transfer-Encoding": "chunked",
    }

    with expectation:
        response = client.post("/", content=multipart_data, headers=headers)
        assert response.status_code == 400
        assert response.text == "Part exceeded maximum size of 1024KB."


@pytest.mark.parametrize(
    "app,expectation",
    [
        (make_app_max_parts(max_part_size=1024 * 10), pytest.raises(MultiPartException)),
        (
            Starlette(routes=[Mount("/", app=make_app_max_parts(max_part_size=1024 * 10))]),
            does_not_raise(),
        ),
    ],
)
def test_max_part_size_exceeds_custom_limit(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    """A field one byte larger than a custom 10KB part limit is rejected."""
    client = test_client_factory(app)
    boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"

    multipart_data = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="small"\r\n\r\n'
        "small content\r\n"
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="large"\r\n\r\n'
        + ("x" * 1024 * 10 + "x")  # 10KB + 1 byte of data
        + "\r\n"
        f"--{boundary}--\r\n"
    ).encode("utf-8")

    headers = {
        "Content-Type": f"multipart/form-data; boundary={boundary}",
        "Transfer-Encoding": "chunked",
    }

    with expectation:
        response = client.post("/", content=multipart_data, headers=headers)
        assert response.status_code == 400
        assert response.text == "Part exceeded maximum size of 10KB."
================================================ FILE: tests/test_requests.py ================================================ from __future__ import annotations import sys from collections.abc import Iterator from typing import Any import anyio import pytest from starlette.datastructures import URL, Address, State from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.types import Message, Receive, Scope, Send from tests.types import TestClientFactory def test_request_url(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = {"method": request.method, "url": str(request.url)} response = JSONResponse(data) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"} response = client.get("https://example.org:123/") assert response.json() == {"method": "GET", "url": "https://example.org:123/"} def test_request_query_params(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) params = dict(request.query_params) response = JSONResponse({"params": params}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/?a=123&b=456") assert response.json() == {"params": {"a": "123", "b": "456"}} @pytest.mark.skipif( any(module in sys.modules for module in ("brotli", "brotlicffi")), reason='urllib3 includes "br" to the "accept-encoding" headers.', ) def test_request_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) headers = dict(request.headers) response = JSONResponse({"headers": headers}) await response(scope, receive, 
send) client = test_client_factory(app) response = client.get("/", headers={"host": "example.org"}) assert response.json() == { "headers": { "host": "example.org", "user-agent": "testclient", "accept-encoding": "gzip, deflate", "accept": "*/*", "connection": "keep-alive", } } @pytest.mark.parametrize( "scope,expected_client", [ ({"client": ["client", 42]}, Address("client", 42)), ({"client": None}, None), ({}, None), ], ) def test_request_client(scope: Scope, expected_client: Address | None) -> None: scope.update({"type": "http"}) # required by Request's constructor client = Request(scope).client assert client == expected_client def test_request_body(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} response = client.post("/", json={"a": "123"}) assert response.json() == {"body": '{"a":"123"}'} response = client.post("/", data="abc") # type: ignore assert response.json() == {"body": "abc"} def test_request_stream(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = b"" async for chunk in request.stream(): body += chunk response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} response = client.post("/", json={"a": "123"}) assert response.json() == {"body": '{"a":"123"}'} response = client.post("/", data="abc") # type: ignore assert response.json() == {"body": "abc"} def test_request_form_urlencoded(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = 
Request(scope, receive) form = await request.form() response = JSONResponse({"form": dict(form)}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", data={"abc": "123 @"}) assert response.json() == {"form": {"abc": "123 @"}} def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) async with request.form() as form: response = JSONResponse({"form": dict(form)}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", data={"abc": "123 @"}) assert response.json() == {"form": {"abc": "123 @"}} def test_request_body_then_stream(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = await request.body() chunks = b"" async for chunk in request.stream(): chunks += chunk response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", data="abc") # type: ignore assert response.json() == {"body": "abc", "stream": "abc"} def test_request_stream_then_body(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) chunks = b"" async for chunk in request.stream(): # pragma: no branch chunks += chunk try: body = await request.body() except RuntimeError: body = b"" response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", data="abc") # type: ignore assert response.json() == {"body": "", "stream": "abc"} def test_request_json(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = 
Request(scope, receive) data = await request.json() response = JSONResponse({"json": data}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": {"a": "123"}} def test_request_scope_interface() -> None: """ A Request can be instantiated with a scope, and presents a `Mapping` interface. """ request = Request({"type": "http", "method": "GET", "path": "/abc/"}) assert request["method"] == "GET" assert dict(request) == {"type": "http", "method": "GET", "path": "/abc/"} assert len(request) == 3 def test_request_raw_path(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) path = request.scope["path"] raw_path = request.scope["raw_path"] response = PlainTextResponse(f"{path}, {raw_path}") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/he%2Fllo") assert response.text == "/he/llo, b'/he%2Fllo'" def test_request_without_setting_receive( test_client_factory: TestClientFactory, ) -> None: """ If Request is instantiated without the receive channel, then .body() is not available. """ async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope) try: data = await request.json() except RuntimeError: data = "Receive channel not available" response = JSONResponse({"json": data}) await response(scope, receive, send) client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": "Receive channel not available"} def test_request_disconnect( anyio_backend_name: str, anyio_backend_options: dict[str, Any], ) -> None: """ If a client disconnect occurs while reading request body then ClientDisconnect should be raised. 
""" async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) await request.body() async def receiver() -> Message: return {"type": "http.disconnect"} scope = {"type": "http", "method": "POST", "path": "/"} with pytest.raises(ClientDisconnect): anyio.run( app, # type: ignore scope, receiver, None, backend=anyio_backend_name, backend_options=anyio_backend_options, ) def test_request_is_disconnected(test_client_factory: TestClientFactory) -> None: """ If a client disconnect occurs after reading request body then request will be set disconnected properly. """ disconnected_after_response = None async def app(scope: Scope, receive: Receive, send: Send) -> None: nonlocal disconnected_after_response request = Request(scope, receive) body = await request.body() disconnected = await request.is_disconnected() response = JSONResponse({"body": body.decode(), "disconnected": disconnected}) await response(scope, receive, send) disconnected_after_response = await request.is_disconnected() client = test_client_factory(app) response = client.post("/", content="foo") assert response.json() == {"body": "foo", "disconnected": False} assert disconnected_after_response def test_request_state_object() -> None: scope = {"state": {"old": "foo"}} s = State(scope["state"]) s.new = "value" assert s.new == "value" del s.new with pytest.raises(AttributeError): s.new # Test dictionary-style methods # Test __setitem__ s["dict_key"] = "dict_value" assert s["dict_key"] == "dict_value" assert s.dict_key == "dict_value" # Test __iter__ s["another_key"] = "another_value" keys = list(s) assert "old" in keys assert "dict_key" in keys assert "another_key" in keys # Test __len__ assert len(s) == 3 # Test __delitem__ del s["dict_key"] assert len(s) == 2 with pytest.raises(KeyError): s["dict_key"] def test_request_state(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, 
receive) request.state.example = 123 response = JSONResponse({"state.example": request.state.example}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"state.example": 123} def test_request_cookies(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) mycookie = request.cookies.get("mycookie") if mycookie: response = Response(mycookie, media_type="text/plain") else: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "Hello, cookies!") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" response = client.get("/") assert response.text == "Hello, cookies!" def test_cookie_lenient_parsing(test_client_factory: TestClientFactory) -> None: """ The following test is based on a cookie set by Okta, a well-known authorization service. It turns out that it's common practice to set cookies that would be invalid according to the spec. 
""" tough_cookie = ( "provider-oauth-nonce=validAsciiblabla; " 'okta-oauth-redirect-params={"responseType":"code","state":"somestate",' '"nonce":"somenonce","scopes":["openid","profile","email","phone"],' '"urls":{"issuer":"https://subdomain.okta.com/oauth2/authServer",' '"authorizeUrl":"https://subdomain.okta.com/oauth2/authServer/v1/authorize",' '"userinfoUrl":"https://subdomain.okta.com/oauth2/authServer/v1/userinfo"}}; ' "importantCookie=importantValue; sessionCookie=importantSessionValue" ) expected_keys = { "importantCookie", "okta-oauth-redirect-params", "provider-oauth-nonce", "sessionCookie", } async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/", headers={"cookie": tough_cookie}) result = response.json() assert len(result["cookies"]) == 4 assert set(result["cookies"].keys()) == expected_keys # These test cases copied from Tornado's implementation @pytest.mark.parametrize( "set_cookie,expected", [ ("chips=ahoy; vienna=finger", {"chips": "ahoy", "vienna": "finger"}), # all semicolons are delimiters, even within quotes ( 'keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"', {"keebler": '"E=mc2', "L": '\\"Loves\\"', "fudge": "\\012", "": '"'}, ), # Illegal cookies that have an '=' char in an unquoted value. ("keebler=E=mc2", {"keebler": "E=mc2"}), # Cookies with ':' character in their name. ("key:term=value:term", {"key:term": "value:term"}), # Cookies with '[' and ']'. ("a=b; c=[; d=r; f=h", {"a": "b", "c": "[", "d": "r", "f": "h"}), # Cookies that RFC6265 allows. ("a=b; Domain=example.com", {"a": "b", "Domain": "example.com"}), # parse_cookie() keeps only the last cookie with the same name. 
("a=b; h=i; a=c", {"a": "c", "h": "i"}), ], ) def test_cookies_edge_cases( set_cookie: str, expected: dict[str, str], test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected @pytest.mark.parametrize( "set_cookie,expected", [ # Chunks without an equals sign appear as unnamed values per # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 ( "abc=def; unnamed; django_language=en", {"": "unnamed", "abc": "def", "django_language": "en"}, ), # Even a double quote may be an unamed value. ('a=b; "; c=d', {"a": "b", "": '"', "c": "d"}), # Spaces in names and values, and an equals sign in values. ("a b c=d e = f; gh=i", {"a b c": "d e = f", "gh": "i"}), # More characters the spec forbids. ('a b,c<>@:/[]?{}=d " =e,f g', {"a b,c<>@:/[]?{}": 'd " =e,f g'}), # Unicode characters. The spec only allows ASCII. # ("saint=André Bessette", {"saint": "André Bessette"}), # Browsers don't send extra whitespace or semicolons in Cookie headers, # but cookie_parser() should parse whitespace the same way # document.cookie parses whitespace. (" = b ; ; = ; c = ; ", {"": "b", "c": ""}), ], ) def test_cookies_invalid( set_cookie: str, expected: dict[str, str], test_client_factory: TestClientFactory, ) -> None: """ Cookie strings that are against the RFC6265 spec but which browsers will send if set via document.cookie. 
""" async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected def test_multiple_cookie_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: scope["headers"] = [(b"cookie", b"a=abc"), (b"cookie", b"b=def"), (b"cookie", b"c=ghi")] request = Request(scope, receive) response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") result = response.json() assert result["cookies"] == {"a": "abc", "b": "def", "c": "ghi"} def test_chunked_encoding(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) client = test_client_factory(app) def post_body() -> Iterator[bytes]: yield b"foo" yield b"bar" response = client.post("/", data=post_body()) # type: ignore assert response.json() == {"body": "foobar"} def test_request_send_push_promise(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: # the server is push-enabled scope["extensions"]["http.response.push"] = {} request = Request(scope, receive, send) await request.send_push_promise("/style.css") response = JSONResponse({"json": "OK"}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "OK"} def test_request_send_push_promise_without_push_extension( test_client_factory: TestClientFactory, ) -> None: """ If server does not support the 
`http.response.push` extension, .send_push_promise() does nothing. """ async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope) await request.send_push_promise("/style.css") response = JSONResponse({"json": "OK"}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "OK"} def test_request_send_push_promise_without_setting_send( test_client_factory: TestClientFactory, ) -> None: """ If Request is instantiated without the send channel, then .send_push_promise() is not available. """ async def app(scope: Scope, receive: Receive, send: Send) -> None: # the server is push-enabled scope["extensions"]["http.response.push"] = {} data = "OK" request = Request(scope) try: await request.send_push_promise("/style.css") except RuntimeError: data = "Send channel not available" response = JSONResponse({"json": data}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "Send channel not available"} @pytest.mark.parametrize( "messages", [ [{"body": b"123", "more_body": True}, {"body": b""}], [{"body": b"", "more_body": True}, {"body": b"123"}], [{"body": b"12", "more_body": True}, {"body": b"3"}], [ {"body": b"123", "more_body": True}, {"body": b"", "more_body": True}, {"body": b""}, ], ], ) @pytest.mark.anyio async def test_request_rcv(messages: list[Message]) -> None: messages = messages.copy() async def rcv() -> Message: return {"type": "http.request", **messages.pop(0)} request = Request({"type": "http"}, rcv) body = await request.body() assert body == b"123" @pytest.mark.anyio async def test_request_stream_called_twice() -> None: messages: list[Message] = [ {"type": "http.request", "body": b"1", "more_body": True}, {"type": "http.request", "body": b"2", "more_body": True}, {"type": "http.request", "body": b"3"}, ] async def rcv() -> Message: return messages.pop(0) request = 
Request({"type": "http"}, rcv) s1 = request.stream() s2 = request.stream() msg = await s1.__anext__() assert msg == b"1" msg = await s2.__anext__() assert msg == b"2" msg = await s1.__anext__() assert msg == b"3" # at this point we've consumed the entire body # so we should not wait for more body (which would hang us forever) msg = await s1.__anext__() assert msg == b"" msg = await s2.__anext__() assert msg == b"" # and now both streams are exhausted with pytest.raises(StopAsyncIteration): assert await s2.__anext__() with pytest.raises(StopAsyncIteration): await s1.__anext__() def test_request_url_outside_starlette_context(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) request.url_for("index") client = test_client_factory(app) with pytest.raises( RuntimeError, match="The `url_for` method can only be used inside a Starlette application or with a router.", ): client.get("/") def test_request_url_starlette_context(test_client_factory: TestClientFactory) -> None: from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.routing import Route from starlette.types import ASGIApp url_for = None async def homepage(request: Request) -> Response: return PlainTextResponse("Hello, world!") class CustomMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: nonlocal url_for request = Request(scope, receive) url_for = request.url_for("homepage") await self.app(scope, receive, send) app = Starlette(routes=[Route("/home", homepage)], middleware=[Middleware(CustomMiddleware)]) client = test_client_factory(app) client.get("/home") assert url_for == URL("http://testserver/home") ================================================ FILE: tests/test_responses.py ================================================ from __future__ import annotations import 
datetime as dt import sys import time from collections.abc import AsyncGenerator, AsyncIterator, Iterator from dataclasses import dataclass from http.cookies import SimpleCookie from pathlib import Path from typing import Any import anyio import pytest from python_multipart import MultipartParser from starlette import status from starlette.background import BackgroundTask from starlette.datastructures import Headers from starlette.requests import ClientDisconnect, Request from starlette.responses import FileResponse, JSONResponse, RedirectResponse, Response, StreamingResponse from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send from tests.types import TestClientFactory def test_text_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("hello, world", media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "hello, world" def test_bytes_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"xxxxx", media_type="image/png") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.content == b"xxxxx" def test_json_none_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse(None) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() is None assert response.content == b"null" def test_redirect_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello, world", media_type="text/plain") else: response = RedirectResponse("/") 
await response(scope, receive, send) client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/" def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/I ♥ Starlette/": response = Response("hello, world", media_type="text/plain") else: response = RedirectResponse("/I ♥ Starlette/") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" def test_redirect_response_content_length_header( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello", media_type="text/plain") # pragma: no cover else: response = RedirectResponse("/") await response(scope, receive, send) client: TestClient = test_client_factory(app) response = client.request("GET", "/redirect", follow_redirects=False) assert response.url == "http://testserver/redirect" assert response.headers["content-length"] == "0" def test_streaming_response(test_client_factory: TestClientFactory) -> None: filled_by_bg_task = "" async def app(scope: Scope, receive: Receive, send: Send) -> None: async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " await anyio.sleep(0) async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: nonlocal filled_by_bg_task async for thing in numbers(start, stop): filled_by_bg_task = filled_by_bg_task + thing cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) generator = numbers(1, 5) response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task) await response(scope, 
receive, send) assert filled_by_bg_task == "" client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" def test_streaming_response_custom_iterator( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterator: def __init__(self) -> None: self._called = 0 def __aiter__(self) -> AsyncIterator[str]: return self async def __anext__(self) -> str: if self._called == 5: raise StopAsyncIteration() self._called += 1 return str(self._called) response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "12345" def test_streaming_response_custom_iterable( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterable: async def __aiter__(self) -> AsyncIterator[str | bytes]: for i in range(5): yield str(i + 1) response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "12345" def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: def numbers(minimum: int, maximum: int) -> Iterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " generator = numbers(1, 5) response = StreamingResponse(generator, media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" def test_response_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: headers = {"x-header-1": "123", 
"x-header-2": "456"} response = Response("hello, world", media_type="text/plain", headers=headers) response.headers["x-header-2"] = "789" await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.headers["x-header-1"] == "123" assert response.headers["x-header-2"] == "789" def test_response_phrase(test_client_factory: TestClientFactory) -> None: app = Response(status_code=204) client = test_client_factory(app) response = client.get("/") assert response.reason_phrase == "No Content" app = Response(b"", status_code=123) client = test_client_factory(app) response = client.get("/") assert response.reason_phrase == "" def test_file_response(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) filled_by_bg_task = "" async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " await anyio.sleep(0) async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: nonlocal filled_by_bg_task async for thing in numbers(start, stop): filled_by_bg_task = filled_by_bg_task + thing cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = FileResponse(path=path, filename="example.png", background=cleanup_task) await response(scope, receive, send) assert filled_by_bg_task == "" client = test_client_factory(app) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK assert response.content == content assert response.headers["content-type"] == "image/png" assert response.headers["content-disposition"] == expected_disposition assert "content-length" in response.headers assert "last-modified" in response.headers assert "etag" in response.headers assert filled_by_bg_task == "6, 7, 8, 9" 
@pytest.mark.anyio async def test_file_response_on_head_method(tmp_path: Path) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) app = FileResponse(path=path, filename="example.png") async def receive() -> Message: # type: ignore[empty-body] ... # pragma: no cover async def send(message: Message) -> None: if message["type"] == "http.response.start": assert message["status"] == status.HTTP_200_OK headers = Headers(raw=message["headers"]) assert headers["content-type"] == "image/png" assert "content-length" in headers assert "content-disposition" in headers assert "last-modified" in headers assert "etag" in headers elif message["type"] == "http.response.body": # pragma: no branch assert message["body"] == b"" assert message["more_body"] is False # Since the TestClient drops the response body on HEAD requests, we need to test # this directly. await app({"type": "http", "method": "head", "headers": [(b"key", b"value")]}, receive, send) def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" path.write_bytes(b"") # By default, FileResponse will determine the `content-type` based on # the filename or path, unless a specific `media_type` is provided. 
app = FileResponse(path=path, filename="example.png", media_type="image/jpeg") client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-type"] == "image/jpeg" def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: app = FileResponse(path=tmp_path, filename="example.png") client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "is not a file" in str(exc_info.value) def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "404.txt" app = FileResponse(path=path, filename="404.txt") client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "does not exist" in str(exc_info.value) def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None: content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = tmp_path / filename path.write_bytes(content) app = FileResponse(path=path, filename=filename) client = test_client_factory(app) response = client.get("/") expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt" assert response.status_code == status.HTTP_200_OK assert response.content == content assert response.headers["content-disposition"] == expected_disposition def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None: content = b"file content" filename = "hello.txt" path = tmp_path / filename path.write_bytes(content) app = FileResponse(path=path, filename=filename, content_disposition_type="inline") client = test_client_factory(app) response = client.get("/") expected_disposition = 'inline; filename="hello.txt"' assert response.status_code == status.HTTP_200_OK assert response.content == content assert response.headers["content-disposition"] == 
expected_disposition def test_file_response_with_range_header(tmp_path: Path, test_client_factory: TestClientFactory) -> None: content = b"file content" filename = "hello.txt" path = tmp_path / filename path.write_bytes(content) etag = '"a_non_autogenerated_etag"' app = FileResponse(path=path, filename=filename, headers={"etag": etag}) client = test_client_factory(app) response = client.get("/", headers={"range": "bytes=0-4", "if-range": etag}) assert response.status_code == status.HTTP_206_PARTIAL_CONTENT assert response.content == content[:5] assert response.headers["etag"] == etag assert response.headers["content-length"] == "5" assert response.headers["content-range"] == f"bytes 0-4/{len(content)}" @pytest.mark.anyio async def test_file_response_with_pathsend(tmpdir: Path) -> None: path = tmpdir / "xyz" content = b"" * 1000 with open(path, "wb") as file: file.write(content) app = FileResponse(path=path, filename="example.png") async def receive() -> Message: # type: ignore[empty-body] ... # pragma: no cover async def send(message: Message) -> None: if message["type"] == "http.response.start": assert message["status"] == status.HTTP_200_OK headers = Headers(raw=message["headers"]) assert headers["content-type"] == "image/png" assert "content-length" in headers assert "content-disposition" in headers assert "last-modified" in headers assert "etag" in headers elif message["type"] == "http.response.pathsend": # pragma: no branch assert message["path"] == str(path) # Since the TestClient doesn't support `pathsend`, we need to test this directly. await app( {"type": "http", "method": "get", "headers": [], "extensions": {"http.response.pathsend": {}}}, receive, send, ) def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. 
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie( "mycookie", "myvalue", max_age=10, expires=10, path="/", domain="localhost", secure=True, httponly=True, samesite="none", partitioned=True if sys.version_info >= (3, 14) else False, ) await response(scope, receive, send) partitioned_text = "Partitioned; " if sys.version_info >= (3, 14) else "" client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" assert ( response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; " f"HttpOnly; Max-Age=10; {partitioned_text}Path=/; SameSite=none; Secure" ) @pytest.mark.skipif(sys.version_info >= (3, 14), reason="Only relevant for <3.14") def test_set_cookie_raises_for_invalid_python_version( test_client_factory: TestClientFactory, ) -> None: # pragma: no cover async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") with pytest.raises(ValueError): response.set_cookie("mycookie", "myvalue", partitioned=True) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" assert response.headers.get("set-cookie") is None def test_set_cookie_path_none(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", path=None) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" 
assert response.headers["set-cookie"] == "mycookie=myvalue; SameSite=lax" def test_set_cookie_samesite_none(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", samesite=None) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" assert response.headers["set-cookie"] == "mycookie=myvalue; Path=/" @pytest.mark.parametrize( "expires", [ pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"), pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"), pytest.param(10, id="int"), ], ) def test_expires_on_set_cookie( test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch, expires: str, ) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", expires=expires) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") cookie = SimpleCookie(response.headers.get("set-cookie")) assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT" def test_delete_cookie(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = Response("Hello, world!", media_type="text/plain") if request.cookies.get("mycookie"): response.delete_cookie("mycookie") else: response.set_cookie("mycookie", "myvalue") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert 
response.cookies["mycookie"] response = client.get("/") assert not response.cookies.get("mycookie") def test_populate_headers(test_client_factory: TestClientFactory) -> None: app = Response(content="hi", headers={}, media_type="text/html") client = test_client_factory(app) response = client.get("/") assert response.text == "hi" assert response.headers["content-length"] == "2" assert response.headers["content-type"] == "text/html; charset=utf-8" def test_head_method(test_client_factory: TestClientFactory) -> None: app = Response("hello, world", media_type="text/plain") client = test_client_factory(app) response = client.head("/") assert response.text == "" def test_empty_response(test_client_factory: TestClientFactory) -> None: app = Response() client: TestClient = test_client_factory(app) response = client.get("/") assert response.content == b"" assert response.headers["content-length"] == "0" assert "content-type" not in response.headers def test_empty_204_response(test_client_factory: TestClientFactory) -> None: app = Response(status_code=204) client: TestClient = test_client_factory(app) response = client.get("/") assert "content-length" not in response.headers def test_non_empty_response(test_client_factory: TestClientFactory) -> None: app = Response(content="hi") client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "2" def test_response_do_not_add_redundant_charset( test_client_factory: TestClientFactory, ) -> None: app = Response(media_type="text/plain; charset=utf-8") client = test_client_factory(app) response = client.get("/") assert response.headers["content-type"] == "text/plain; charset=utf-8" def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) app = FileResponse(path=path, filename="example.png") client: TestClient = test_client_factory(app) response = client.get("/") 
assert response.headers["content-length"] == str(len(content)) def test_streaming_response_unknown_size( test_client_factory: TestClientFactory, ) -> None: app = StreamingResponse(content=iter(["hello", "world"])) client: TestClient = test_client_factory(app) response = client.get("/") assert "content-length" not in response.headers def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None: app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"}) client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "10" def test_response_memoryview(test_client_factory: TestClientFactory) -> None: app = Response(content=memoryview(b"\xc0")) client: TestClient = test_client_factory(app) response = client.get("/") assert response.content == b"\xc0" def test_streaming_response_memoryview(test_client_factory: TestClientFactory) -> None: app = StreamingResponse(content=iter([memoryview(b"\xc0"), memoryview(b"\xf5")])) client: TestClient = test_client_factory(app) response = client.get("/") assert response.content == b"\xc0\xf5" @pytest.mark.anyio async def test_streaming_response_stops_if_receiving_http_disconnect() -> None: streamed = 0 disconnected = anyio.Event() async def receive_disconnect() -> Message: await disconnected.wait() return {"type": "http.disconnect"} async def send(message: Message) -> None: nonlocal streamed if message["type"] == "http.response.body": streamed += len(message.get("body", b"")) # Simulate disconnection after download has started if streamed >= 16: disconnected.set() async def stream_indefinitely() -> AsyncIterator[bytes]: while True: # Need a sleep for the event loop to switch to another task await anyio.sleep(0) yield b"chunk " response = StreamingResponse(content=stream_indefinitely()) with anyio.move_on_after(1) as cancel_scope: await response({"type": "http"}, receive_disconnect, send) assert not cancel_scope.cancel_called, 
"Content streaming should stop itself." @pytest.mark.anyio async def test_streaming_response_on_client_disconnects() -> None: chunks = bytearray() streamed = False async def receive_disconnect() -> Message: raise NotImplementedError async def send(message: Message) -> None: nonlocal streamed if message["type"] == "http.response.body": if not streamed: chunks.extend(message.get("body", b"")) streamed = True else: raise OSError async def stream_indefinitely() -> AsyncGenerator[bytes, None]: while True: await anyio.sleep(0) yield b"chunk" stream = stream_indefinitely() response = StreamingResponse(content=stream) with anyio.move_on_after(1) as cancel_scope: with pytest.raises(ClientDisconnect): await response({"type": "http", "asgi": {"spec_version": "2.4"}}, receive_disconnect, send) assert not cancel_scope.cancel_called, "Content streaming should stop itself." assert chunks == b"chunk" await stream.aclose() @pytest.mark.anyio async def test_streaming_response_runs_background_on_websocket_scope() -> None: background_called = False sent: list[Message] = [] async def receive() -> Message: return {} # pragma: no cover async def send(message: Message) -> None: sent.append(message) def run_background() -> None: nonlocal background_called background_called = True async def stream() -> AsyncIterator[bytes]: yield b"chunk" response = StreamingResponse(stream(), background=BackgroundTask(run_background)) await response({"type": "websocket"}, receive, send) assert background_called assert [message["type"] for message in sent] == [ "websocket.http.response.start", "websocket.http.response.body", "websocket.http.response.body", ] README = """\ # BáiZé Powerful and exquisite WSGI/ASGI framework/toolkit. The minimize implementation of methods required in the Web framework. No redundant implementation means that you can freely customize functions without considering the conflict with baize's own implementation. 
Under the ASGI/WSGI protocol, the interface of the request object and the response object is almost the same, only need to add or delete `await` in the appropriate place. In addition, it should be noted that ASGI supports WebSocket but WSGI does not. """ # noqa: E501 @pytest.fixture def readme_file(tmp_path: Path) -> Path: filepath = tmp_path / "README.txt" filepath.write_bytes(README.encode("utf8")) return filepath @pytest.fixture def file_response_client(readme_file: Path, test_client_factory: TestClientFactory) -> TestClient: return test_client_factory(app=FileResponse(str(readme_file))) def test_file_response_without_range(file_response_client: TestClient) -> None: response = file_response_client.get("/") assert response.status_code == 200 assert "content-range" not in response.headers assert response.headers["content-length"] == str(len(README.encode("utf8"))) assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.text == README def test_file_response_head(file_response_client: TestClient) -> None: response = file_response_client.head("/") assert response.status_code == 200 assert "content-range" not in response.headers assert response.headers["content-length"] == str(len(README.encode("utf8"))) assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.content == b"" def test_file_response_range(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=0-100"}) assert response.status_code == 206 assert response.headers["content-range"] == f"bytes 0-100/{len(README.encode('utf8'))}" assert response.headers["content-length"] == "101" assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.content == README.encode("utf8")[:101] def test_file_response_range_head(file_response_client: TestClient) -> None: response = file_response_client.head("/", headers={"Range": "bytes=0-100"}) assert response.status_code == 206 assert 
response.headers["content-range"] == f"bytes 0-100/{len(README.encode('utf8'))}" assert response.headers["content-length"] == str(101) assert response.headers["content-type"] == "text/plain; charset=utf-8" assert response.content == b"" def test_file_response_range_multi(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=0-100, 200-300"}) assert response.status_code == 206 assert "content-range" not in response.headers assert response.headers["content-length"] == "448" assert response.headers["content-type"].startswith("multipart/byteranges; boundary=") def test_file_response_range_multi_head(file_response_client: TestClient) -> None: response = file_response_client.head("/", headers={"Range": "bytes=0-100, 200-300"}) assert response.status_code == 206 assert "content-range" not in response.headers assert response.headers["content-length"] == "448" assert response.headers["content-type"].startswith("multipart/byteranges; boundary=") assert response.content == b"" response = file_response_client.head( "/", headers={"Range": "bytes=200-300", "if-range": response.headers["etag"][:-1]}, ) assert response.status_code == 200 response = file_response_client.head( "/", headers={"Range": "bytes=200-300", "if-range": response.headers["etag"]}, ) assert response.status_code == 206 def test_file_response_range_invalid(file_response_client: TestClient) -> None: response = file_response_client.head("/", headers={"Range": "bytes: 0-1000"}) assert response.status_code == 400 def test_file_response_range_head_max(file_response_client: TestClient) -> None: response = file_response_client.head("/", headers={"Range": f"bytes=0-{len(README.encode('utf8')) + 1}"}) assert response.status_code == 206 def test_file_response_range_416(file_response_client: TestClient) -> None: response = file_response_client.head("/", headers={"Range": f"bytes={len(README.encode('utf8')) + 1}-"}) assert response.status_code == 416 assert 
response.headers["Content-Range"] == f"bytes */{len(README.encode('utf8'))}" def test_file_response_only_support_bytes_range(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "items=0-100"}) assert response.status_code == 400 assert response.text == "Only support bytes range" def test_file_response_range_must_be_requested(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes="}) assert response.status_code == 400 assert response.text == "Range header: range must be requested" def test_file_response_start_must_be_less_than_end(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=100-0"}) assert response.status_code == 400 assert response.text == "Range header: start must be less than end" def test_file_response_merge_ranges(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=0-100, 50-200"}) assert response.status_code == 206 assert response.headers["content-length"] == "201" assert response.headers["content-range"] == f"bytes 0-200/{len(README.encode('utf8'))}" @dataclass class MultipartPart: headers: dict[bytes, bytes] data: bytes def parse_multipart_data(data: bytes, boundary: bytes | str) -> list[MultipartPart]: parts: list[MultipartPart] = [] done = False current_headers: dict[bytes, bytes] = {} current_header_field: bytes = b"" def on_part_begin() -> None: nonlocal current_headers current_headers = {} def on_part_data(data: bytes, start: int, end: int) -> None: parts.append(MultipartPart(current_headers, data[start:end])) def on_header_field(data: bytes, start: int, end: int) -> None: nonlocal current_header_field current_header_field = data[start:end] def on_header_value(data: bytes, start: int, end: int) -> None: current_headers[current_header_field] = data[start:end] def on_end() -> None: nonlocal done done = True parser = MultipartParser( 
boundary, dict( on_part_begin=on_part_begin, on_part_data=on_part_data, on_header_field=on_header_field, on_header_value=on_header_value, on_end=on_end, ), ) parser.write(data) parser.finalize() assert done return parts def test_file_response_insert_ranges(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=100-200, 0-50"}) assert response.status_code == 206 assert "content-range" not in response.headers assert response.headers["content-type"].startswith("multipart/byteranges; boundary=") boundary = response.headers["content-type"].split("boundary=")[1] assert response.text.splitlines() == [ f"--{boundary}", "Content-Type: text/plain; charset=utf-8", "Content-Range: bytes 0-50/526", "", "# BáiZé", "", "Powerful and exquisite WSGI/ASGI framewo", f"--{boundary}", "Content-Type: text/plain; charset=utf-8", "Content-Range: bytes 100-200/526", "", "ds required in the Web framework. No redundant implementation means that you can freely customize fun", f"--{boundary}--", ] parts = parse_multipart_data(response._content, boundary) assert all( value == b"text/plain; charset=utf-8" for part in parts for key, value in part.headers.items() if key == b"Content-Type" ) assert len(parts) == 2 assert parts[0].headers[b"Content-Range"] == b"bytes 0-50/526" assert parts[0].data == "# BáiZé\n\nPowerful and exquisite WSGI/ASGI framewo".encode() assert parts[1].headers[b"Content-Range"] == b"bytes 100-200/526" assert ( parts[1].data == b"ds required in the Web framework. 
No redundant implementation means that you can freely customize fun" ) def test_file_response_range_without_dash(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=100, 0-50"}) assert response.status_code == 206 assert response.headers["content-range"] == f"bytes 0-50/{len(README.encode('utf8'))}" def test_file_response_range_empty_start_and_end(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes= - , 0-50"}) assert response.status_code == 206 assert response.headers["content-range"] == f"bytes 0-50/{len(README.encode('utf8'))}" def test_file_response_range_ignore_non_numeric(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=abc-def, 0-50"}) assert response.status_code == 206 assert response.headers["content-range"] == f"bytes 0-50/{len(README.encode('utf8'))}" def test_file_response_suffix_range(file_response_client: TestClient) -> None: # Test suffix range (last N bytes) - line 523 with empty start_str response = file_response_client.get("/", headers={"Range": "bytes=-100"}) assert response.status_code == 206 file_size = len(README.encode("utf8")) assert response.headers["content-range"] == f"bytes {file_size - 100}-{file_size - 1}/{file_size}" assert response.headers["content-length"] == "100" assert response.content == README.encode("utf8")[-100:] def test_file_response_multiple_calls(file_response_client: TestClient) -> None: response = file_response_client.get("/", headers={"Range": "bytes=0-100"}) assert response.status_code == 206 response = file_response_client.get("/") assert response.status_code == 200 assert "content-range" not in response.headers assert response.headers["content-length"] == str(len(README.encode("utf8"))) assert response.headers["content-type"] == "text/plain; charset=utf-8" @pytest.mark.anyio async def test_file_response_multi_small_chunk_size(readme_file: Path) 
-> None: class SmallChunkSizeFileResponse(FileResponse): chunk_size = 10 app = SmallChunkSizeFileResponse(path=str(readme_file)) received_chunks: list[bytes] = [] start_message: dict[str, Any] = {} async def receive() -> Message: raise NotImplementedError("Should not be called!") async def send(message: Message) -> None: if message["type"] == "http.response.start": start_message.update(message) elif message["type"] == "http.response.body": # pragma: no branch received_chunks.append(message["body"]) await app({"type": "http", "method": "get", "headers": [(b"range", b"bytes=0-15,20-35,35-50")]}, receive, send) assert start_message["status"] == 206 headers = Headers(raw=start_message["headers"]) assert "content-range" not in headers assert headers.get("accept-ranges") == "bytes" assert "content-length" in headers assert "last-modified" in headers assert "etag" in headers assert headers["content-type"].startswith("multipart/byteranges; boundary=") boundary = headers["content-type"].split("boundary=")[1] assert received_chunks == [ # Send the part headers. f"--{boundary}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Range: bytes 0-15/526\r\n\r\n".encode(), # Send the first chunk (10 bytes). b"# B\xc3\xa1iZ\xc3\xa9\n", # Send the second chunk (6 bytes). b"\nPower", # Send the new line to separate the parts. b"\r\n", # Send the part headers. We merge the ranges 20-35 and 35-50 into a single part. f"--{boundary}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Range: bytes 20-50/526\r\n\r\n".encode(), # Send the first chunk (10 bytes). b"and exquis", # Send the second chunk (10 bytes). b"ite WSGI/A", # Send the third chunk (10 bytes). b"SGI framew", # Send the last chunk (1 byte). 
b"o", b"\r\n", f"--{boundary}--".encode(), ] ================================================ FILE: tests/test_routing.py ================================================ from __future__ import annotations import contextlib import functools import json import uuid from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator from typing import TypedDict import pytest from typing_extensions import Never from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect from tests.types import TestClientFactory def homepage(request: Request) -> Response: return Response("Hello, world", media_type="text/plain") def users(request: Request) -> Response: return Response("All users", media_type="text/plain") def user(request: Request) -> Response: content = "User " + request.path_params["username"] return Response(content, media_type="text/plain") def user_me(request: Request) -> Response: content = "User fixed me" return Response(content, media_type="text/plain") def disable_user(request: Request) -> Response: content = "User " + request.path_params["username"] + " disabled" return Response(content, media_type="text/plain") def user_no_match(request: Request) -> Response: # pragma: no cover content = "User fixed no match" return Response(content, media_type="text/plain") async def partial_endpoint(arg: str, request: Request) -> JSONResponse: return JSONResponse({"arg": arg}) async def partial_ws_endpoint(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_json({"url": str(websocket.url)}) await 
websocket.close() class PartialRoutes: @classmethod async def async_endpoint(cls, arg: str, request: Request) -> JSONResponse: return JSONResponse({"arg": arg}) @classmethod async def async_ws_endpoint(cls, websocket: WebSocket) -> None: await websocket.accept() await websocket.send_json({"url": str(websocket.url)}) await websocket.close() def func_homepage(request: Request) -> Response: return Response("Hello, world!", media_type="text/plain") def contact(request: Request) -> Response: return Response("Hello, POST!", media_type="text/plain") def int_convertor(request: Request) -> JSONResponse: number = request.path_params["param"] return JSONResponse({"int": number}) def float_convertor(request: Request) -> JSONResponse: num = request.path_params["param"] return JSONResponse({"float": num}) def path_convertor(request: Request) -> JSONResponse: path = request.path_params["param"] return JSONResponse({"path": path}) def uuid_converter(request: Request) -> JSONResponse: uuid_param = request.path_params["param"] return JSONResponse({"uuid": str(uuid_param)}) def path_with_parentheses(request: Request) -> JSONResponse: number = request.path_params["param"] return JSONResponse({"int": number}) async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() async def websocket_params(session: WebSocket) -> None: await session.accept() await session.send_text(f"Hello, {session.path_params['room']}!") await session.close() app = Router( [ Route("/", endpoint=homepage, methods=["GET"]), Mount( "/users", routes=[ Route("/", endpoint=users), Route("/me", endpoint=user_me), Route("/{username}", endpoint=user), Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]), Route("/nomatch", endpoint=user_no_match), ], ), Mount( "/partial", routes=[ Route("/", endpoint=functools.partial(partial_endpoint, "foo")), Route( "/cls", endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"), 
), WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)), WebSocketRoute( "/ws/cls", endpoint=functools.partial(PartialRoutes.async_ws_endpoint), ), ], ), Mount("/static", app=Response("xxxxx", media_type="image/png")), Route("/func", endpoint=func_homepage, methods=["GET"]), Route("/func", endpoint=contact, methods=["POST"]), Route("/int/{param:int}", endpoint=int_convertor, name="int-convertor"), Route("/float/{param:float}", endpoint=float_convertor, name="float-convertor"), Route("/path/{param:path}", endpoint=path_convertor, name="path-convertor"), Route("/uuid/{param:uuid}", endpoint=uuid_converter, name="uuid-convertor"), # Route with chars that conflict with regex meta chars Route( "/path-with-parentheses({param:int})", endpoint=path_with_parentheses, name="path-with-parentheses", ), WebSocketRoute("/ws", endpoint=websocket_endpoint), WebSocketRoute("/ws/{room}", endpoint=websocket_params), ] ) @pytest.fixture def client( test_client_factory: TestClientFactory, ) -> Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client @pytest.mark.filterwarnings( r"ignore" r":Trying to detect encoding from a tiny portion of \(5\) byte\(s\)\." 
r":UserWarning" r":charset_normalizer.api" ) def test_router(client: TestClient) -> None: response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world" response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" assert set(response.headers["allow"].split(", ")) == {"HEAD", "GET"} response = client.get("/foo") assert response.status_code == 404 assert response.text == "Not Found" response = client.get("/users") assert response.status_code == 200 assert response.text == "All users" response = client.get("/users/tomchristie") assert response.status_code == 200 assert response.text == "User tomchristie" response = client.get("/users/me") assert response.status_code == 200 assert response.text == "User fixed me" response = client.get("/users/tomchristie/") assert response.status_code == 200 assert response.url == "http://testserver/users/tomchristie" assert response.text == "User tomchristie" response = client.put("/users/tomchristie:disable") assert response.status_code == 200 assert response.url == "http://testserver/users/tomchristie:disable" assert response.text == "User tomchristie disabled" response = client.get("/users/nomatch") assert response.status_code == 200 assert response.text == "User nomatch" response = client.get("/static/123") assert response.status_code == 200 assert response.text == "xxxxx" def test_route_converters(client: TestClient) -> None: # Test integer conversion response = client.get("/int/5") assert response.status_code == 200 assert response.json() == {"int": 5} assert app.url_path_for("int-convertor", param=5) == "/int/5" # Test path with parentheses response = client.get("/path-with-parentheses(7)") assert response.status_code == 200 assert response.json() == {"int": 7} assert app.url_path_for("path-with-parentheses", param=7) == "/path-with-parentheses(7)" # Test float conversion response = client.get("/float/25.5") assert response.status_code == 200 
assert response.json() == {"float": 25.5} assert app.url_path_for("float-convertor", param=25.5) == "/float/25.5" # Test path conversion response = client.get("/path/some/example") assert response.status_code == 200 assert response.json() == {"path": "some/example"} assert app.url_path_for("path-convertor", param="some/example") == "/path/some/example" # Test UUID conversion response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a") assert response.status_code == 200 assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"} assert ( app.url_path_for("uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")) == "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a" ) def test_url_path_for() -> None: assert app.url_path_for("homepage") == "/" assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie" assert app.url_path_for("websocket_endpoint") == "/ws" with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "".'): assert app.url_path_for("broken") with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "key, key2".'): assert app.url_path_for("broken", key="value", key2="value2") with pytest.raises(AssertionError): app.url_path_for("user", username="tom/christie") with pytest.raises(AssertionError): app.url_path_for("user", username="") def test_url_for() -> None: assert app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/" assert ( app.url_path_for("homepage").make_absolute_url(base_url="https://example.org/root_path/") == "https://example.org/root_path/" ) assert ( app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org") == "https://example.org/users/tomchristie" ) assert ( app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org/root_path/") == "https://example.org/root_path/users/tomchristie" ) assert ( 
app.url_path_for("websocket_endpoint").make_absolute_url(base_url="https://example.org") == "wss://example.org/ws" ) def test_router_add_route(client: TestClient) -> None: response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" def test_router_duplicate_path(client: TestClient) -> None: response = client.post("/func") assert response.status_code == 200 assert response.text == "Hello, POST!" def test_router_add_websocket_route(client: TestClient) -> None: with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" with client.websocket_connect("/ws/test") as session: text = session.receive_text() assert text == "Hello, test!" def test_router_middleware(test_client_factory: TestClientFactory) -> None: class CustomMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response = PlainTextResponse("OK") await response(scope, receive, send) app = Router( routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)], ) client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "OK" def http_endpoint(request: Request) -> Response: url = request.url_for("http_endpoint") return Response(f"URL: {url}", media_type="text/plain") class WebSocketEndpoint: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope=scope, receive=receive, send=send) await websocket.accept() await websocket.send_json({"URL": str(websocket.url_for("websocket_endpoint"))}) await websocket.close() mixed_protocol_app = Router( routes=[ Route("/", endpoint=http_endpoint), WebSocketRoute("/", endpoint=WebSocketEndpoint(), name="websocket_endpoint"), ] ) def test_protocol_switch(test_client_factory: TestClientFactory) -> None: client = test_client_factory(mixed_protocol_app) response = 
client.get("/") assert response.status_code == 200 assert response.text == "URL: http://testserver/" with client.websocket_connect("/") as session: assert session.receive_json() == {"URL": "ws://testserver/"} with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/404"): pass # pragma: no cover ok = PlainTextResponse("OK") def test_mount_urls(test_client_factory: TestClientFactory) -> None: mounted = Router([Mount("/users", ok, name="users")]) client = test_client_factory(mounted) assert client.get("/users").status_code == 200 assert client.get("/users").url == "http://testserver/users/" assert client.get("/users/").status_code == 200 assert client.get("/users/a").status_code == 200 assert client.get("/usersa").status_code == 404 def test_reverse_mount_urls() -> None: mounted = Router([Mount("/users", ok, name="users")]) assert mounted.url_path_for("users", path="/a") == "/users/a" users = Router([Route("/{username}", ok, name="user")]) mounted = Router([Mount("/{subpath}/users", users, name="users")]) assert mounted.url_path_for("users:user", subpath="test", username="tom") == "/test/users/tom" assert mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom" mounted = Router([Mount("/users", ok, name="users")]) with pytest.raises(NoMatchFound): mounted.url_path_for("users", path="/a", foo="bar") mounted = Router([Mount("/users", ok, name="users")]) with pytest.raises(NoMatchFound): mounted.url_path_for("users") def test_mount_at_root(test_client_factory: TestClientFactory) -> None: mounted = Router([Mount("/", ok, name="users")]) client = test_client_factory(mounted) assert client.get("/").status_code == 200 def users_api(request: Request) -> JSONResponse: return JSONResponse({"users": [{"username": "tom"}]}) mixed_hosts_app = Router( routes=[ Host( "www.example.org", app=Router( [ Route("/", homepage, name="homepage"), Route("/users", users, name="users"), ] ), ), Host( "api.example.org", name="api", 
app=Router([Route("/users", users_api, name="users")]), ), Host( "port.example.org:3600", name="port", app=Router([Route("/", homepage, name="homepage")]), ), ] ) def test_host_routing(test_client_factory: TestClientFactory) -> None: client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/") response = client.get("/users") assert response.status_code == 200 assert response.json() == {"users": [{"username": "tom"}]} response = client.get("/") assert response.status_code == 404 client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/") response = client.get("/users") assert response.status_code == 200 assert response.text == "All users" response = client.get("/") assert response.status_code == 200 client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/") response = client.get("/users") assert response.status_code == 404 response = client.get("/") assert response.status_code == 200 # Port in requested Host is irrelevant. 
client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/") response = client.get("/") assert response.status_code == 200 client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/") response = client.get("/") assert response.status_code == 200 def test_host_reverse_urls() -> None: assert mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/" assert ( mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") == "https://www.example.org/users" ) assert ( mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever") == "https://api.example.org/users" ) assert ( mixed_hosts_app.url_path_for("port:homepage").make_absolute_url("https://whatever") == "https://port.example.org:3600/" ) with pytest.raises(NoMatchFound): mixed_hosts_app.url_path_for("api", path="whatever", foo="bar") async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse({"subdomain": scope["path_params"]["subdomain"]}) await response(scope, receive, send) subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]) def test_subdomain_routing(test_client_factory: TestClientFactory) -> None: client = test_client_factory(subdomain_router, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 assert response.json() == {"subdomain": "foo"} def test_subdomain_reverse_urls() -> None: assert ( subdomain_router.url_path_for("subdomains", subdomain="foo", path="/homepage").make_absolute_url( "https://whatever" ) == "https://foo.example.org/homepage" ) async def echo_urls(request: Request) -> JSONResponse: return JSONResponse( { "index": str(request.url_for("index")), "submount": str(request.url_for("mount:submount")), } ) echo_url_routes = [ Route("/", echo_urls, name="index", methods=["GET"]), Mount( "/submount", name="mount", 
routes=[Route("/", echo_urls, name="submount", methods=["GET"])], ), ] def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None: app = Starlette(routes=echo_url_routes) client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path") response = client.get("/sub_path/") assert response.json() == { "index": "https://www.example.org/sub_path/", "submount": "https://www.example.org/sub_path/submount/", } response = client.get("/sub_path/submount/") assert response.json() == { "index": "https://www.example.org/sub_path/", "submount": "https://www.example.org/sub_path/submount/", } async def stub_app(scope: Scope, receive: Receive, send: Send) -> None: pass # pragma: no cover double_mount_routes = [ Mount("/mount", name="mount", routes=[Mount("/static", stub_app, name="static")]), ] def test_url_for_with_double_mount() -> None: app = Starlette(routes=double_mount_routes) url = app.url_path_for("mount:static", path="123") assert url == "/mount/static/123" def test_url_for_with_root_path_ending_with_slash(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> JSONResponse: return JSONResponse({"index": str(request.url_for("homepage"))}) app = Starlette(routes=[Route("/", homepage, name="homepage")]) client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path/") response = client.get("/sub_path/") assert response.json() == {"index": "https://www.example.org/sub_path/"} def test_standalone_route_matches( test_client_factory: TestClientFactory, ) -> None: app = Route("/", PlainTextResponse("Hello, World!")) client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" 
def test_standalone_route_does_not_match( test_client_factory: Callable[..., TestClient], ) -> None: app = Route("/", PlainTextResponse("Hello, World!")) client = test_client_factory(app) response = client.get("/invalid") assert response.status_code == 404 assert response.text == "Not Found" async def ws_helloworld(websocket: WebSocket) -> None: await websocket.accept() await websocket.send_text("Hello, world!") await websocket.close() def test_standalone_ws_route_matches( test_client_factory: TestClientFactory, ) -> None: app = WebSocketRoute("/", ws_helloworld) client = test_client_factory(app) with client.websocket_connect("/") as websocket: text = websocket.receive_text() assert text == "Hello, world!" def test_standalone_ws_route_does_not_match( test_client_factory: TestClientFactory, ) -> None: app = WebSocketRoute("/", ws_helloworld) client = test_client_factory(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/invalid"): pass # pragma: no cover def test_lifespan_state_unsupported(test_client_factory: TestClientFactory) -> None: @contextlib.asynccontextmanager async def lifespan(app: ASGIApp) -> AsyncGenerator[dict[str, str], None]: yield {"foo": "bar"} app = Router( lifespan=lifespan, routes=[Mount("/", PlainTextResponse("hello, world"))], ) async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None: del scope["state"] await app(scope, receive, send) with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'): with test_client_factory(no_state_wrapper): raise AssertionError("Should not be called") # pragma: no cover def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None: startup_complete = False shutdown_complete = False class State(TypedDict): count: int items: list[int] async def hello_world(request: Request) -> Response: # modifications to the state should not leak across requests assert request.state.count == 0 # modify the state, this 
should not leak to the lifespan or other requests request.state.count += 1 # since state.items is a mutable object this modification _will_ leak across # requests and to the lifespan request.state.items.append(1) return PlainTextResponse("hello, world") @contextlib.asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[State]: nonlocal startup_complete, shutdown_complete startup_complete = True state = State(count=0, items=[]) yield state shutdown_complete = True # modifications made to the state from a request do not leak to the lifespan assert state["count"] == 0 # unless of course the request mutates a mutable object that is referenced # via state assert state["items"] == [1, 1] app = Router( lifespan=lifespan, routes=[Route("/", hello_world)], ) assert not startup_complete assert not shutdown_complete with test_client_factory(app) as client: assert startup_complete assert not shutdown_complete client.get("/") # Calling it a second time to ensure that the state is preserved. 
client.get("/") assert startup_complete assert shutdown_complete def test_raise_on_startup(test_client_factory: TestClientFactory) -> None: @contextlib.asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[Never]: raise RuntimeError() yield # pragma: no cover router = Router(lifespan=lifespan) startup_failed = False async def app(scope: Scope, receive: Receive, send: Send) -> None: async def _send(message: Message) -> None: nonlocal startup_failed if message["type"] == "lifespan.startup.failed": # pragma: no branch startup_failed = True return await send(message) await router(scope, receive, _send) with pytest.raises(RuntimeError): with test_client_factory(app): pass # pragma: no cover assert startup_failed def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None: @contextlib.asynccontextmanager async def lifespan(app: Starlette) -> AsyncIterator[None]: yield raise RuntimeError("Shutdown failed") app = Router(lifespan=lifespan) with pytest.raises(RuntimeError, match="Shutdown failed"): with test_client_factory(app): pass # pragma: no cover def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None: test_client = test_client_factory(app) response = test_client.get("/partial") assert response.status_code == 200 assert response.json() == {"arg": "foo"} cls_method_response = test_client.get("/partial/cls") assert cls_method_response.status_code == 200 assert cls_method_response.json() == {"arg": "foo"} def test_partial_async_ws_endpoint( test_client_factory: TestClientFactory, ) -> None: test_client = test_client_factory(app) with test_client.websocket_connect("/partial/ws") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/partial/ws"} with test_client.websocket_connect("/partial/ws/cls") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/partial/ws/cls"} def test_duplicated_param_names() -> None: with pytest.raises( ValueError, 
match="Duplicated param name id at path /{id}/{id}", ): Route("/{id}/{id}", user) with pytest.raises( ValueError, match="Duplicated param names id, name at path /{id}/{name}/{id}/{name}", ): Route("/{id}/{name}/{id}/{name}", user) class Endpoint: async def my_method(self, request: Request) -> None: ... # pragma: no cover @classmethod async def my_classmethod(cls, request: Request) -> None: ... # pragma: no cover @staticmethod async def my_staticmethod(request: Request) -> None: ... # pragma: no cover def __call__(self, request: Request) -> None: ... # pragma: no cover @pytest.mark.parametrize( "endpoint, expected_name", [ pytest.param(func_homepage, "func_homepage", id="function"), pytest.param(Endpoint().my_method, "my_method", id="method"), pytest.param(Endpoint.my_classmethod, "my_classmethod", id="classmethod"), pytest.param( Endpoint.my_staticmethod, "my_staticmethod", id="staticmethod", ), pytest.param(Endpoint(), "Endpoint", id="object"), pytest.param(lambda request: ..., "", id="lambda"), # pragma: no branch ], ) def test_route_name(endpoint: Callable[..., Response], expected_name: str) -> None: assert Route(path="/", endpoint=endpoint).name == expected_name class AddHeadersMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["add_headers_middleware"] = True async def modified_send(msg: Message) -> None: if msg["type"] == "http.response.start": msg["headers"].append((b"X-Test", b"Set by middleware")) await send(msg) await self.app(scope, receive, modified_send) def assert_middleware_header_route(request: Request) -> Response: assert request.scope["add_headers_middleware"] is True return Response() route_with_middleware = Starlette( routes=[ Route( "/http", endpoint=assert_middleware_header_route, methods=["GET"], middleware=[Middleware(AddHeadersMiddleware)], ), Route("/home", homepage), ] ) mounted_routes_with_middleware = Starlette( routes=[ Mount( "/http", 
routes=[ Route( "/", endpoint=assert_middleware_header_route, methods=["GET"], name="route", ), ], middleware=[Middleware(AddHeadersMiddleware)], ), Route("/home", homepage), ] ) mounted_app_with_middleware = Starlette( routes=[ Mount( "/http", app=Route( "/", endpoint=assert_middleware_header_route, methods=["GET"], name="route", ), middleware=[Middleware(AddHeadersMiddleware)], ), Route("/home", homepage), ] ) @pytest.mark.parametrize( "app", [ mounted_routes_with_middleware, mounted_app_with_middleware, route_with_middleware, ], ) def test_base_route_middleware( test_client_factory: TestClientFactory, app: Starlette, ) -> None: test_client = test_client_factory(app) response = test_client.get("/home") assert response.status_code == 200 assert "X-Test" not in response.headers response = test_client.get("/http") assert response.status_code == 200 assert response.headers["X-Test"] == "Set by middleware" def test_mount_routes_with_middleware_url_path_for() -> None: """Checks that url_path_for still works with mounted routes with Middleware""" assert mounted_routes_with_middleware.url_path_for("route") == "/http/" def test_mount_asgi_app_with_middleware_url_path_for() -> None: """Mounted ASGI apps do not work with url path for, middleware does not change this """ with pytest.raises(NoMatchFound): mounted_app_with_middleware.url_path_for("route") def test_add_route_to_app_after_mount( test_client_factory: Callable[..., TestClient], ) -> None: """Checks that Mount will pick up routes added to the underlying app after it is mounted """ inner_app = Router() app = Mount("/http", app=inner_app) inner_app.add_route( "/inner", endpoint=homepage, methods=["GET"], ) client = test_client_factory(app) response = client.get("/http/inner") assert response.status_code == 200 def test_exception_on_mounted_apps( test_client_factory: TestClientFactory, ) -> None: def exc(request: Request) -> None: raise Exception("Exc") sub_app = Starlette(routes=[Route("/", exc)]) app = 
Starlette(routes=[Mount("/sub", app=sub_app)]) client = test_client_factory(app) with pytest.raises(Exception) as ctx: client.get("/sub/") assert str(ctx.value) == "Exc" def test_mounted_middleware_does_not_catch_exception( test_client_factory: Callable[..., TestClient], ) -> None: # https://github.com/Kludex/starlette/pull/1649#discussion_r960236107 def exc(request: Request) -> Response: raise HTTPException(status_code=403, detail="auth") class NamedMiddleware: def __init__(self, app: ASGIApp, name: str) -> None: self.app = app self.name = name async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def modified_send(msg: Message) -> None: if msg["type"] == "http.response.start": msg["headers"].append((f"X-{self.name}".encode(), b"true")) await send(msg) await self.app(scope, receive, modified_send) app = Starlette( routes=[ Mount( "/mount", routes=[ Route("/err", exc), Route("/home", homepage), ], middleware=[Middleware(NamedMiddleware, name="Mounted")], ), Route("/err", exc), Route("/home", homepage), ], middleware=[Middleware(NamedMiddleware, name="Outer")], ) client = test_client_factory(app) resp = client.get("/home") assert resp.status_code == 200, resp.content assert "X-Outer" in resp.headers resp = client.get("/err") assert resp.status_code == 403, resp.content assert "X-Outer" in resp.headers resp = client.get("/mount/home") assert resp.status_code == 200, resp.content assert "X-Mounted" in resp.headers resp = client.get("/mount/err") assert resp.status_code == 403, resp.content assert "X-Mounted" in resp.headers def test_websocket_route_middleware( test_client_factory: TestClientFactory, ) -> None: async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() class WebsocketMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def modified_send(msg: 
Message) -> None: if msg["type"] == "websocket.accept": msg["headers"].append((b"X-Test", b"Set by middleware")) await send(msg) await self.app(scope, receive, modified_send) app = Starlette( routes=[ WebSocketRoute( "/ws", endpoint=websocket_endpoint, middleware=[Middleware(WebsocketMiddleware)], ) ] ) client = test_client_factory(app) with client.websocket_connect("/ws") as websocket: text = websocket.receive_text() assert text == "Hello, world!" assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")] def test_route_repr() -> None: route = Route("/welcome", endpoint=homepage) assert repr(route) == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])" def test_route_repr_without_methods() -> None: route = Route("/welcome", endpoint=Endpoint, methods=None) assert repr(route) == "Route(path='/welcome', name='Endpoint', methods=[])" def test_websocket_route_repr() -> None: route = WebSocketRoute("/ws", endpoint=websocket_endpoint) assert repr(route) == "WebSocketRoute(path='/ws', name='websocket_endpoint')" def test_mount_repr() -> None: route = Mount( "/app", routes=[ Route("/", endpoint=homepage), ], ) # test for substring because repr(Router) returns unique object ID assert repr(route).startswith("Mount(path='/app', name='', app=") def test_mount_named_repr() -> None: route = Mount( "/app", name="app", routes=[ Route("/", endpoint=homepage), ], ) # test for substring because repr(Router) returns unique object ID assert repr(route).startswith("Mount(path='/app', name='app', app=") def test_host_repr() -> None: route = Host( "example.com", app=Router( [ Route("/", endpoint=homepage), ] ), ) # test for substring because repr(Router) returns unique object ID assert repr(route).startswith("Host(host='example.com', name='', app=") def test_host_named_repr() -> None: route = Host( "example.com", name="app", app=Router( [ Route("/", endpoint=homepage), ] ), ) # test for substring because repr(Router) returns unique object ID assert 
repr(route).startswith("Host(host='example.com', name='app', app=") async def echo_paths(request: Request, name: str) -> JSONResponse: return JSONResponse( { "name": name, "path": request.scope["path"], "root_path": request.scope["root_path"], } ) async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str) -> None: data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]} content = json.dumps(data).encode("utf-8") await send( { "type": "http.response.start", "status": 200, "headers": [(b"content-type", b"application/json")], } ) await send({"type": "http.response.body", "body": content}) echo_paths_routes = [ Route( "/path", functools.partial(echo_paths, name="path"), name="path", methods=["GET"], ), Route( "/root-queue/path", functools.partial(echo_paths, name="queue_path"), name="queue_path", methods=["POST"], ), Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")), Mount( "/sub", name="mount", routes=[ Route( "/path", functools.partial(echo_paths, name="subpath"), name="subpath", methods=["GET"], ), ], ), ] def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None: app = Starlette(routes=echo_paths_routes) client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root") response = client.get("/root/path") assert response.status_code == 200 assert response.json() == { "name": "path", "path": "/root/path", "root_path": "/root", } response = client.get("/root/asgipath/") assert response.status_code == 200 assert response.json() == { "name": "asgipath", "path": "/root/asgipath/", # Things that mount other ASGI apps, like WSGIMiddleware, would not be aware # of the prefixed path, and would have their own notion of their own paths, # so they need to be able to rely on the root_path to know the location they # are mounted on "root_path": "/root/asgipath", } response = client.get("/root/sub/path") assert response.status_code == 200 assert 
response.json() == { "name": "subpath", "path": "/root/sub/path", "root_path": "/root/sub", } response = client.post("/root/root-queue/path") assert response.status_code == 200 assert response.json() == { "name": "queue_path", "path": "/root/root-queue/path", "root_path": "/root", } ================================================ FILE: tests/test_schemas.py ================================================ from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.schemas import SchemaGenerator from starlette.websockets import WebSocket from tests.types import TestClientFactory schemas = SchemaGenerator({"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}) def ws(session: WebSocket) -> None: """ws""" pass # pragma: no cover def get_user(request: Request) -> None: """ responses: 200: description: A user. examples: {"username": "tom"} """ pass # pragma: no cover def list_users(request: Request) -> None: """ responses: 200: description: A list of users. examples: [{"username": "tom"}, {"username": "lucy"}] """ pass # pragma: no cover def create_user(request: Request) -> None: """ responses: 200: description: A user. examples: {"username": "tom"} """ pass # pragma: no cover class OrganisationsEndpoint(HTTPEndpoint): def get(self, request: Request) -> None: """ responses: 200: description: A list of organisations. examples: [{"name": "Foo Corp."}, {"name": "Acme Ltd."}] """ pass # pragma: no cover def post(self, request: Request) -> None: """ responses: 200: description: An organisation. examples: {"name": "Foo Corp."} """ pass # pragma: no cover def regular_docstring_and_schema(request: Request) -> None: """ This a regular docstring example (not included in schema) --- responses: 200: description: This is included in the schema. 
""" pass # pragma: no cover def regular_docstring(request: Request) -> None: """ This a regular docstring example (not included in schema) """ pass # pragma: no cover def no_docstring(request: Request) -> None: pass # pragma: no cover def subapp_endpoint(request: Request) -> None: """ responses: 200: description: This endpoint is part of a subapp. """ pass # pragma: no cover def schema(request: Request) -> Response: return schemas.OpenAPIResponse(request=request) subapp = Starlette( routes=[ Route("/subapp-endpoint", endpoint=subapp_endpoint), ] ) app = Starlette( routes=[ WebSocketRoute("/ws", endpoint=ws), Route("/users/{id:int}", endpoint=get_user, methods=["GET"]), Route("/users", endpoint=list_users, methods=["GET", "HEAD"]), Route("/users", endpoint=create_user, methods=["POST"]), Route("/orgs", endpoint=OrganisationsEndpoint), Route("/regular-docstring-and-schema", endpoint=regular_docstring_and_schema), Route("/regular-docstring", endpoint=regular_docstring), Route("/no-docstring", endpoint=no_docstring), Route("/schema", endpoint=schema, methods=["GET"], include_in_schema=False), Mount("/subapp", subapp), Host("sub.domain.com", app=Router(routes=[Mount("/subapp2", subapp)])), ] ) def test_schema_generation() -> None: schema = schemas.get_schema(routes=app.routes) assert schema == { "openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}, "paths": { "/orgs": { "get": { "responses": { 200: { "description": "A list of organisations.", "examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}], } } }, "post": { "responses": { 200: { "description": "An organisation.", "examples": {"name": "Foo Corp."}, } } }, }, "/regular-docstring-and-schema": { "get": {"responses": {200: {"description": "This is included in the schema."}}} }, "/subapp/subapp-endpoint": { "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}} }, "/subapp2/subapp-endpoint": { "get": {"responses": {200: {"description": "This endpoint is part of a 
subapp."}}} }, "/users": { "get": { "responses": { 200: { "description": "A list of users.", "examples": [{"username": "tom"}, {"username": "lucy"}], } } }, "post": {"responses": {200: {"description": "A user.", "examples": {"username": "tom"}}}}, }, "/users/{id}": { "get": { "responses": { 200: { "description": "A user.", "examples": {"username": "tom"}, } } }, }, }, } EXPECTED_SCHEMA = """ info: title: Example API version: '1.0' openapi: 3.0.0 paths: /orgs: get: responses: 200: description: A list of organisations. examples: - name: Foo Corp. - name: Acme Ltd. post: responses: 200: description: An organisation. examples: name: Foo Corp. /regular-docstring-and-schema: get: responses: 200: description: This is included in the schema. /subapp/subapp-endpoint: get: responses: 200: description: This endpoint is part of a subapp. /subapp2/subapp-endpoint: get: responses: 200: description: This endpoint is part of a subapp. /users: get: responses: 200: description: A list of users. examples: - username: tom - username: lucy post: responses: 200: description: A user. examples: username: tom /users/{id}: get: responses: 200: description: A user. 
examples: username: tom """ def test_schema_endpoint(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.get("/schema") assert response.headers["Content-Type"] == "application/vnd.oai.openapi" assert response.text.strip() == EXPECTED_SCHEMA.strip() ================================================ FILE: tests/test_staticfiles.py ================================================ import os import stat import tempfile import time from pathlib import Path from typing import Any import anyio import pytest from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount from starlette.staticfiles import StaticFiles from tests.types import TestClientFactory def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" def test_staticfiles_with_pathlib(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "example.txt" with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmp_path) client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" def test_staticfiles_head_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None: """ see https://github.com/Kludex/starlette/pull/935 """ path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("x" * 100) async def does_nothing_middleware(request: Request, call_next: 
RequestResponseEndpoint) -> Response: response = await call_next(request) return response routes = [Mount("/static", app=StaticFiles(directory=tmpdir), name="static")] middleware = [Middleware(BaseHTTPMiddleware, dispatch=does_nothing_middleware)] app = Starlette(routes=routes, middleware=middleware) client = test_client_factory(app) response = client.head("/static/example.txt") assert response.status_code == 200 assert response.headers.get("content-length") == "100" def test_staticfiles_with_package(test_client_factory: TestClientFactory) -> None: app = StaticFiles(packages=["tests"]) client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" app = StaticFiles(packages=[("tests", "statics")]) client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" def test_staticfiles_post(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app) response = client.post("/example.txt") assert response.status_code == 405 assert response.text == "Method Not Allowed" def test_staticfiles_with_directory_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app) response = client.get("/") assert response.status_code == 404 assert response.text == "Not Found" def test_staticfiles_with_missing_file_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = 
[Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app) response = client.get("/404.txt") assert response.status_code == 404 assert response.text == "Not Found" def test_staticfiles_instantiated_with_missing_directory(tmpdir: Path) -> None: with pytest.raises(RuntimeError) as exc_info: path = os.path.join(tmpdir, "no_such_directory") StaticFiles(directory=path) assert "does not exist" in str(exc_info.value) def test_staticfiles_configured_with_missing_directory(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "no_such_directory") app = StaticFiles(directory=path, check_dir=False) client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "does not exist" in str(exc_info.value) def test_staticfiles_configured_with_file_instead_of_directory( tmpdir: Path, test_client_factory: TestClientFactory ) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=path, check_dir=False) client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "is not a directory" in str(exc_info.value) def test_staticfiles_config_check_occurs_only_once(tmpdir: Path, test_client_factory: TestClientFactory) -> None: app = StaticFiles(directory=tmpdir) client = test_client_factory(app) assert not app.config_checked with pytest.raises(HTTPException): client.get("/") assert app.config_checked with pytest.raises(HTTPException): client.get("/") def test_staticfiles_prevents_breaking_out_of_directory(tmpdir: Path) -> None: directory = os.path.join(tmpdir, "foo") os.mkdir(directory) path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("outside root dir") app = StaticFiles(directory=directory) # We can't test this with 'httpx', so we test the app directly here. 
path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"} with pytest.raises(HTTPException) as exc_info: anyio.run(app.get_response, path, scope) assert exc_info.value.status_code == 404 assert exc_info.value.detail == "Not Found" def test_staticfiles_never_read_file_for_head_method(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) client = test_client_factory(app) response = client.head("/example.txt") assert response.status_code == 200 assert response.content == b"" assert response.headers["content-length"] == "14" def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) client = test_client_factory(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 last_etag = first_resp.headers["etag"] second_resp = client.get("/example.txt", headers={"if-none-match": last_etag}) assert second_resp.status_code == 304 assert second_resp.content == b"" second_resp = client.get("/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'}) assert second_resp.status_code == 304 assert second_resp.content == b"" def test_staticfiles_200_with_etag_mismatch(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) client = test_client_factory(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 assert first_resp.headers["etag"] != '"123"' second_resp = client.get("/example.txt", headers={"if-none-match": '"123"'}) assert second_resp.status_code == 200 assert second_resp.content == b"" def test_staticfiles_200_with_etag_mismatch_and_timestamp_match( tmpdir: Path, 
test_client_factory: TestClientFactory ) -> None: path = tmpdir / "example.txt" path.write_text("", encoding="utf-8") app = StaticFiles(directory=tmpdir) client = test_client_factory(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 assert first_resp.headers["etag"] != '"123"' last_modified = first_resp.headers["last-modified"] # If `if-none-match` is present, `if-modified-since` is ignored. second_resp = client.get("/example.txt", headers={"if-none-match": '"123"', "if-modified-since": last_modified}) assert second_resp.status_code == 200 assert second_resp.content == b"" def test_staticfiles_304_with_last_modified_compare_last_req( tmpdir: Path, test_client_factory: TestClientFactory ) -> None: path = os.path.join(tmpdir, "example.txt") file_last_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")) with open(path, "w") as file: file.write("") os.utime(path, (file_last_modified_time, file_last_modified_time)) app = StaticFiles(directory=tmpdir) client = test_client_factory(app) # last modified less than last request, 304 response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"}) assert response.status_code == 304 assert response.content == b"" # last modified greater than last request, 200 with content response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"}) assert response.status_code == 200 assert response.content == b"" def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

    Custom not found page

    ") path = os.path.join(tmpdir, "dir") os.mkdir(path) path = os.path.join(path, "index.html") with open(path, "w") as file: file.write("

    Hello

    ") app = StaticFiles(directory=tmpdir, html=True) client = test_client_factory(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" assert response.status_code == 200 assert response.text == "

    Hello

    " response = client.get("/dir") assert response.url == "http://testserver/dir/" assert response.status_code == 200 assert response.text == "

    Hello

    " response = client.get("/dir/index.html") assert response.url == "http://testserver/dir/index.html" assert response.status_code == 200 assert response.text == "

    Hello

    " response = client.get("/missing") assert response.status_code == 404 assert response.text == "

    Custom not found page

    " def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

    Custom not found page

    ") path = os.path.join(tmpdir, "dir") os.mkdir(path) app = StaticFiles(directory=tmpdir, html=True) client = test_client_factory(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" assert response.status_code == 404 assert response.text == "

    Custom not found page

    " response = client.get("/dir") assert response.url == "http://testserver/dir" assert response.status_code == 404 assert response.text == "

    Custom not found page

    " response = client.get("/missing") assert response.status_code == 404 assert response.text == "

    Custom not found page

    " def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "dir") os.mkdir(path) path = os.path.join(path, "index.html") with open(path, "w") as file: file.write("

    Hello

    ") app = StaticFiles(directory=tmpdir, html=True) client = test_client_factory(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" assert response.status_code == 200 assert response.text == "

    Hello

    " response = client.get("/dir") assert response.url == "http://testserver/dir/" assert response.status_code == 200 assert response.text == "

    Hello

    " with pytest.raises(HTTPException) as exc_info: response = client.get("/missing") assert exc_info.value.status_code == 404 def test_staticfiles_html_only_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "hello.html") with open(path, "w") as file: file.write("

    Hello

    ") app = StaticFiles(directory=tmpdir, html=True) client = test_client_factory(app) with pytest.raises(HTTPException) as exc_info: response = client.get("/") assert exc_info.value.status_code == 404 response = client.get("/hello.html") assert response.status_code == 200 assert response.text == "

    Hello

    " def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( tmpdir: Path, test_client_factory: TestClientFactory ) -> None: path_404 = os.path.join(tmpdir, "404.html") with open(path_404, "w") as file: file.write("

    404 file

    ") path_some = os.path.join(tmpdir, "some.html") with open(path_some, "w") as file: file.write("

    some file

    ") common_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")) os.utime(path_404, (common_modified_time, common_modified_time)) os.utime(path_some, (common_modified_time, common_modified_time)) app = StaticFiles(directory=tmpdir, html=True) client = test_client_factory(app) resp_exists = client.get("/some.html") assert resp_exists.status_code == 200 assert resp_exists.text == "

    some file

    " resp_cached = client.get( "/some.html", headers={"If-Modified-Since": resp_exists.headers["last-modified"]}, ) assert resp_cached.status_code == 304 os.remove(path_some) resp_deleted = client.get( "/some.html", headers={"If-Modified-Since": resp_exists.headers["last-modified"]}, ) assert resp_deleted.status_code == 404 assert resp_deleted.text == "

    404 file

    " def test_staticfiles_with_invalid_dir_permissions_returns_401( tmp_path: Path, test_client_factory: TestClientFactory ) -> None: (tmp_path / "example.txt").write_bytes(b"") original_mode = tmp_path.stat().st_mode tmp_path.chmod(stat.S_IRWXO) try: routes = [ Mount( "/", app=StaticFiles(directory=os.fsdecode(tmp_path)), name="static", ) ] app = Starlette(routes=routes) client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 401 assert response.text == "Unauthorized" finally: tmp_path.chmod(original_mode) def test_staticfiles_with_missing_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app) response = client.get("/foo/example.txt") assert response.status_code == 404 assert response.text == "Not Found" def test_staticfiles_access_file_as_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app) response = client.get("/example.txt/foo") assert response.status_code == 404 assert response.text == "Not Found" def test_staticfiles_null_byte_in_path(tmpdir: Path, test_client_factory: TestClientFactory) -> None: routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app) response = client.get("/example%00.txt") assert response.status_code == 404 def test_staticfiles_filename_too_long(tmpdir: Path, test_client_factory: TestClientFactory) -> None: routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = 
test_client_factory(app) path_max_size = os.pathconf("/", "PC_PATH_MAX") response = client.get(f"/{'a' * path_max_size}.txt") assert response.status_code == 404 assert response.text == "Not Found" def test_staticfiles_unhandled_os_error_returns_500( tmpdir: Path, test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch, ) -> None: def mock_timeout(*args: Any, **kwargs: Any) -> None: raise TimeoutError path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")] app = Starlette(routes=routes) client = test_client_factory(app, raise_server_exceptions=False) monkeypatch.setattr("starlette.staticfiles.StaticFiles.lookup_path", mock_timeout) response = client.get("/example.txt") assert response.status_code == 500 assert response.text == "Internal Server Error" def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None: statics_path = os.path.join(tmpdir, "statics") os.mkdir(statics_path) source_path = tempfile.mkdtemp() source_file_path = os.path.join(source_path, "page.html") with open(source_file_path, "w") as file: file.write("

    Hello

    ") statics_file_path = os.path.join(statics_path, "index.html") os.symlink(source_file_path, statics_file_path) app = StaticFiles(directory=statics_path, follow_symlink=True) client = test_client_factory(app) response = client.get("/index.html") assert response.url == "http://testserver/index.html" assert response.status_code == 200 assert response.text == "

    Hello

    " def test_staticfiles_follows_symlink_directories(tmpdir: Path, test_client_factory: TestClientFactory) -> None: statics_path = os.path.join(tmpdir, "statics") statics_html_path = os.path.join(statics_path, "html") os.mkdir(statics_path) source_path = tempfile.mkdtemp() source_file_path = os.path.join(source_path, "page.html") with open(source_file_path, "w") as file: file.write("

    Hello

    ") os.symlink(source_path, statics_html_path) app = StaticFiles(directory=statics_path, follow_symlink=True) client = test_client_factory(app) response = client.get("/html/page.html") assert response.url == "http://testserver/html/page.html" assert response.status_code == 200 assert response.text == "

    Hello

    " def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir: Path) -> None: statics_path = os.path.join(tmpdir, "statics") root_source_path = tempfile.mkdtemp() source_path = os.path.join(root_source_path, "statics") os.mkdir(source_path) source_file_path = os.path.join(root_source_path, "index.html") with open(source_file_path, "w") as file: file.write("

    Hello

    ") os.symlink(source_path, statics_path) app = StaticFiles(directory=statics_path, follow_symlink=True) # We can't test this with 'httpx', so we test the app directly here. path = app.get_path({"path": "/../index.html"}) scope = {"method": "GET"} with pytest.raises(HTTPException) as exc_info: anyio.run(app.get_response, path, scope) assert exc_info.value.status_code == 404 assert exc_info.value.detail == "Not Found" def test_staticfiles_avoids_path_traversal(tmp_path: Path) -> None: statics_path = tmp_path / "static" statics_disallow_path = tmp_path / "static_disallow" statics_path.mkdir() statics_disallow_path.mkdir() static_index_file = statics_path / "index.html" statics_disallow_path_index_file = statics_disallow_path / "index.html" static_file = tmp_path / "static1.txt" static_index_file.write_text("

    Hello

    ") statics_disallow_path_index_file.write_text("

    Private

    ") static_file.write_text("Private") app = StaticFiles(directory=statics_path) # We can't test this with 'httpx', so we test the app directly here. path = app.get_path({"path": "/../static1.txt"}) with pytest.raises(HTTPException) as exc_info: anyio.run(app.get_response, path, {"method": "GET"}) assert exc_info.value.status_code == 404 assert exc_info.value.detail == "Not Found" path = app.get_path({"path": "/../static_disallow/index.html"}) with pytest.raises(HTTPException) as exc_info: anyio.run(app.get_response, path, {"method": "GET"}) assert exc_info.value.status_code == 404 assert exc_info.value.detail == "Not Found" def test_staticfiles_self_symlinks(tmp_path: Path, test_client_factory: TestClientFactory) -> None: statics_path = tmp_path / "statics" statics_path.mkdir() source_file_path = statics_path / "index.html" source_file_path.write_text("

    Hello

    ", encoding="utf-8") statics_symlink_path = tmp_path / "statics_symlink" statics_symlink_path.symlink_to(statics_path) app = StaticFiles(directory=statics_symlink_path, follow_symlink=True) client = test_client_factory(app) response = client.get("/index.html") assert response.url == "http://testserver/index.html" assert response.status_code == 200 assert response.text == "

    Hello

    " def test_staticfiles_relative_directory_symlinks(test_client_factory: TestClientFactory) -> None: app = StaticFiles(directory="tests/statics", follow_symlink=True) client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" ================================================ FILE: tests/test_status.py ================================================ import importlib import pytest @pytest.mark.parametrize( "constant,msg", ( ( "HTTP_413_REQUEST_ENTITY_TOO_LARGE", "'HTTP_413_REQUEST_ENTITY_TOO_LARGE' is deprecated. Use 'HTTP_413_CONTENT_TOO_LARGE' instead.", ), ( "HTTP_414_REQUEST_URI_TOO_LONG", "'HTTP_414_REQUEST_URI_TOO_LONG' is deprecated. Use 'HTTP_414_URI_TOO_LONG' instead.", ), ), ) def test_deprecated_types(constant: str, msg: str) -> None: with pytest.warns(DeprecationWarning) as record: getattr(importlib.import_module("starlette.status"), constant) assert len(record) == 1 assert msg in str(record.list[0]) def test_unknown_status() -> None: with pytest.raises( AttributeError, match="module 'starlette.status' has no attribute 'HTTP_999_UNKNOWN_STATUS_CODE'", ): getattr(importlib.import_module("starlette.status"), "HTTP_999_UNKNOWN_STATUS_CODE") ================================================ FILE: tests/test_templates.py ================================================ from __future__ import annotations import os from pathlib import Path import jinja2 import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route from starlette.templating import Jinja2Templates from tests.types import TestClientFactory def test_templates(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") 
as file: file.write("Hello, world") async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") app = Starlette(debug=True, routes=[Route("/", endpoint=homepage)]) templates = Jinja2Templates(directory=str(tmpdir)) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" # type: ignore assert set(response.context.keys()) == {"request"} # type: ignore def test_templates_autoescape(tmp_path: Path) -> None: path = tmp_path / "index.html" path.write_text("Hello, {{ name }}") templates = Jinja2Templates(directory=tmp_path) template = templates.get_template("index.html") assert ( template.render(name="") == "Hello, <script>alert('XSS')</script>" ) def test_calls_context_processors(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "index.html" path.write_text("Hello {{ username }}") async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") def hello_world_processor(request: Request) -> dict[str, str]: return {"username": "World"} app = Starlette( debug=True, routes=[Route("/", endpoint=homepage)], ) templates = Jinja2Templates( directory=tmp_path, context_processors=[ hello_world_processor, ], ) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello World" assert response.template.name == "index.html" # type: ignore assert set(response.context.keys()) == {"request", "username"} # type: ignore def test_template_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> 
Response: return await call_next(request) app = Starlette( debug=True, routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CustomMiddleware)], ) templates = Jinja2Templates(directory=str(tmpdir)) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" # type: ignore assert set(response.context.keys()) == {"request"} # type: ignore def test_templates_with_directories(tmp_path: Path, test_client_factory: TestClientFactory) -> None: dir_a = tmp_path.resolve() / "a" dir_a.mkdir() template_a = dir_a / "template_a.html" template_a.write_text(" a") async def page_a(request: Request) -> Response: return templates.TemplateResponse(request, "template_a.html") dir_b = tmp_path.resolve() / "b" dir_b.mkdir() template_b = dir_b / "template_b.html" template_b.write_text(" b") async def page_b(request: Request) -> Response: return templates.TemplateResponse(request, "template_b.html") app = Starlette( debug=True, routes=[Route("/a", endpoint=page_a), Route("/b", endpoint=page_b)], ) templates = Jinja2Templates(directory=[dir_a, dir_b]) client = test_client_factory(app) response = client.get("/a") assert response.text == " a" assert response.template.name == "template_a.html" # type: ignore assert set(response.context.keys()) == {"request"} # type: ignore response = client.get("/b") assert response.text == " b" assert response.template.name == "template_b.html" # type: ignore assert set(response.context.keys()) == {"request"} # type: ignore def test_templates_require_directory_or_environment() -> None: with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"): Jinja2Templates() # type: ignore[call-overload] def test_templates_require_directory_or_environment_not_both() -> None: with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"): Jinja2Templates(directory="dir", env=jinja2.Environment()) # type: 
ignore[call-overload] def test_templates_with_directory(tmpdir: Path) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello") templates = Jinja2Templates(directory=str(tmpdir)) template = templates.get_template("index.html") assert template.render({}) == "Hello" def test_templates_with_environment(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") env = jinja2.Environment(loader=jinja2.FileSystemLoader(str(tmpdir))) app = Starlette( debug=True, routes=[Route("/", endpoint=homepage)], ) templates = Jinja2Templates(env=env) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" # type: ignore assert set(response.context.keys()) == {"request"} # type: ignore ================================================ FILE: tests/test_testclient.py ================================================ from __future__ import annotations import itertools import sys from asyncio import Task, current_task as asyncio_current_task from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Any import anyio import anyio.lowlevel import pytest import sniffio import trio.lowlevel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, RedirectResponse, Response from starlette.routing import Route from starlette.testclient import ASGIInstance, TestClient from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect from tests.types import TestClientFactory def mock_service_endpoint(request: Request) -> JSONResponse: return 
JSONResponse({"mock": "example"}) mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)]) def current_task() -> Task[Any] | trio.lowlevel.Task: # anyio's TaskInfo comparisons are invalid after their associated native # task object is GC'd https://github.com/agronholm/anyio/issues/324 asynclib_name = sniffio.current_async_library() if asynclib_name == "trio": return trio.lowlevel.current_task() if asynclib_name == "asyncio": task = asyncio_current_task() if task is None: raise RuntimeError("must be called from a running task") # pragma: no cover return task raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover def test_use_testclient_in_endpoint(test_client_factory: TestClientFactory) -> None: """ We should be able to use the test client within applications. This is useful if we need to mock out other services, during tests or in development. """ def homepage(request: Request) -> JSONResponse: client = test_client_factory(mock_service) response = client.get("/") return JSONResponse(response.json()) app = Starlette(routes=[Route("/", endpoint=homepage)]) client = test_client_factory(app) response = client.get("/") assert response.json() == {"mock": "example"} def test_testclient_headers_behavior() -> None: """ We should be able to use the test client with user defined headers. This is useful if we need to set custom headers for authentication during tests or in development. 
""" client = TestClient(mock_service) assert client.headers.get("user-agent") == "testclient" client = TestClient(mock_service, headers={"user-agent": "non-default-agent"}) assert client.headers.get("user-agent") == "non-default-agent" client = TestClient(mock_service, headers={"Authentication": "Bearer 123"}) assert client.headers.get("user-agent") == "testclient" assert client.headers.get("Authentication") == "Bearer 123" def test_use_testclient_as_contextmanager(test_client_factory: TestClientFactory, anyio_backend_name: str) -> None: """ This test asserts a number of properties that are important for an app level task_group """ counter = itertools.count() identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar") def get_identity() -> int: try: return identity_runvar.get() except LookupError: token = next(counter) identity_runvar.set(token) return token startup_task = object() startup_loop = None shutdown_task = object() shutdown_loop = None @asynccontextmanager async def lifespan_context(app: Starlette) -> AsyncGenerator[None, None]: nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop startup_task = current_task() startup_loop = get_identity() async with anyio.create_task_group(): yield shutdown_task = current_task() shutdown_loop = get_identity() async def loop_id(request: Request) -> JSONResponse: return JSONResponse(get_identity()) app = Starlette( lifespan=lifespan_context, routes=[Route("/loop_id", endpoint=loop_id)], ) client = test_client_factory(app) with client: # within a TestClient context every async request runs in the same thread assert client.get("/loop_id").json() == 0 assert client.get("/loop_id").json() == 0 # that thread is also the same as the lifespan thread assert startup_loop == 0 assert shutdown_loop == 0 # lifespan events run in the same task, this is important because a task # group must be entered and exited in the same task. 
assert startup_task is shutdown_task # outside the TestClient context, new requests continue to spawn in new # event loops in new threads assert client.get("/loop_id").json() == 1 assert client.get("/loop_id").json() == 2 first_task = startup_task with client: # the TestClient context can be re-used, starting a new lifespan task # in a new thread assert client.get("/loop_id").json() == 3 assert client.get("/loop_id").json() == 3 assert startup_loop == 3 assert shutdown_loop == 3 # lifespan events still run in the same task, with the context but... assert startup_task is shutdown_task # ... the second TestClient context creates a new lifespan task. assert first_task is not startup_task def test_error_on_startup(test_client_factory: TestClientFactory) -> None: @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: raise RuntimeError("Startup error") yield startup_error_app = Starlette(lifespan=lifespan) with pytest.raises(RuntimeError, match="Startup error"): with test_client_factory(startup_error_app): pass # pragma: no cover def test_exception_in_middleware(test_client_factory: TestClientFactory) -> None: class MiddlewareException(Exception): pass class BrokenMiddleware: def __init__(self, app: ASGIApp): self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: raise MiddlewareException() broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) with pytest.raises(MiddlewareException): with test_client_factory(broken_middleware): pass # pragma: no cover def test_testclient_asgi2(test_client_factory: TestClientFactory) -> None: def app(scope: Scope) -> ASGIInstance: async def inner(receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": 200, "headers": [[b"content-type", b"text/plain"]], } ) await send({"type": "http.response.body", "body": b"Hello, world!"}) return inner client = test_client_factory(app) # type: ignore response = client.get("/") 
assert response.text == "Hello, world!" def test_testclient_asgi3(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": 200, "headers": [[b"content-type", b"text/plain"]], } ) await send({"type": "http.response.body", "body": b"Hello, world!"}) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" def test_websocket_blocking_receive(test_client_factory: TestClientFactory) -> None: def app(scope: Scope) -> ASGIInstance: async def respond(websocket: WebSocket) -> None: await websocket.send_json({"message": "test"}) async def asgi(receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() async with anyio.create_task_group() as task_group: task_group.start_soon(respond, websocket) try: # this will block as the client does not send us data # it should not prevent `respond` from executing though await websocket.receive_json() except WebSocketDisconnect: pass return asgi client = test_client_factory(app) # type: ignore with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None: cancelled = False def app(scope: Scope) -> ASGIInstance: async def asgi(receive: Receive, send: Send) -> None: nonlocal cancelled try: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await anyio.sleep_forever() except anyio.get_cancelled_exc_class(): cancelled = True raise return asgi client = test_client_factory(app) # type: ignore with client.websocket_connect("/"): ... 
assert cancelled def test_client(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: client = scope.get("client") assert client is not None host, port = client response = JSONResponse({"host": host, "port": port}) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() == {"host": "testclient", "port": 50000} def test_client_custom_client(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: client = scope.get("client") assert client is not None host, port = client response = JSONResponse({"host": host, "port": port}) await response(scope, receive, send) client = test_client_factory(app, client=("192.168.0.1", 3000)) response = client.get("/") assert response.json() == {"host": "192.168.0.1", "port": 3000} @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà")) def test_query_params(test_client_factory: TestClientFactory, param: str) -> None: def homepage(request: Request) -> Response: return Response(request.query_params["param"]) app = Starlette(routes=[Route("/", endpoint=homepage)]) client = test_client_factory(app) response = client.get("/", params={"param": param}) assert response.text == param @pytest.mark.parametrize( "domain, ok", [ pytest.param( "testserver", True, marks=[ pytest.mark.xfail( sys.version_info < (3, 11), reason="Fails due to domain handling in http.cookiejar module (see #2152)", ), ], ), ("testserver.local", True), ("localhost", False), ("example.com", False), ], ) def test_domain_restricted_cookies(test_client_factory: TestClientFactory, domain: str, ok: bool) -> None: """ Test that test client discards domain restricted cookies which do not match the base_url of the testclient (`http://testserver` by default). 
The domain `testserver.local` works because the Python http.cookiejar module derives the "effective domain" by appending `.local` to non-dotted request domains in accordance with RFC 2965. """ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie( "mycookie", "myvalue", path="/", domain=domain, ) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") cookie_set = len(response.cookies) == 1 assert cookie_set == ok def test_forward_follow_redirects(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if "/ok" in scope["path"]: response = Response("ok") else: response = RedirectResponse("/ok") await response(scope, receive, send) client = test_client_factory(app, follow_redirects=True) response = client.get("/") assert response.status_code == 200 def test_forward_nofollow_redirects(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = RedirectResponse("/ok") await response(scope, receive, send) client = test_client_factory(app, follow_redirects=False) response = client.get("/") assert response.status_code == 307 def test_with_duplicate_headers(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> JSONResponse: return JSONResponse({"x-token": request.headers.getlist("x-token")}) app = Starlette(routes=[Route("/", endpoint=homepage)]) client = test_client_factory(app) response = client.get("/", headers=[("x-token", "foo"), ("x-token", "bar")]) assert response.json() == {"x-token": ["foo", "bar"]} def test_merge_url(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> Response: return Response(request.url.path) app = Starlette(routes=[Route("/api/v1/bar", endpoint=homepage)]) client = test_client_factory(app, 
base_url="http://testserver/api/v1/") response = client.get("/bar") assert response.text == "/api/v1/bar" def test_raw_path_with_querystring(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(scope.get("raw_path")) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/hello-world", params={"foo": "bar"}) assert response.content == b"/hello-world" def test_websocket_raw_path_without_params(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() raw_path = scope.get("raw_path") assert raw_path is not None await websocket.send_bytes(raw_path) client = test_client_factory(app) with client.websocket_connect("/hello-world", params={"foo": "bar"}) as websocket: data = websocket.receive_bytes() assert data == b"/hello-world" def test_timeout_deprecation() -> None: with pytest.deprecated_call(match="You should not use the 'timeout' argument with the TestClient."): client = TestClient(mock_service) client.get("/", timeout=1) ================================================ FILE: tests/test_websockets.py ================================================ import sys from collections.abc import AsyncGenerator, MutableMapping from pathlib import Path from typing import Any import anyio import pytest from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette import status from starlette.responses import FileResponse, Response, StreamingResponse from starlette.testclient import WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState from tests.types import TestClientFactory def test_websocket_url(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: 
Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.send_json({"url": str(websocket.url)}) await websocket.close() client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/123?a=abc"} def test_websocket_binary_json(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() message = await websocket.receive_json(mode="binary") await websocket.send_json(message, mode="binary") await websocket.close() client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: websocket.send_json({"test": "data"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"test": "data"} def test_websocket_ensure_unicode_on_send_json( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() message = await websocket.receive_json(mode="text") await websocket.send_json(message, mode="text") await websocket.close() client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: websocket.send_json({"test": "数据"}, mode="text") data = websocket.receive_text() assert data == '{"test":"数据"}' def test_websocket_query_params(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) query_params = dict(websocket.query_params) await websocket.accept() await websocket.send_json({"params": query_params}) await websocket.close() client = test_client_factory(app) with client.websocket_connect("/?a=abc&b=456") as websocket: data = websocket.receive_json() assert data == 
{"params": {"a": "abc", "b": "456"}} @pytest.mark.skipif( any(module in sys.modules for module in ("brotli", "brotlicffi")), reason='urllib3 includes "br" to the "accept-encoding" headers.', ) def test_websocket_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) headers = dict(websocket.headers) await websocket.accept() await websocket.send_json({"headers": headers}) await websocket.close() client = test_client_factory(app) with client.websocket_connect("/") as websocket: expected_headers = { "accept": "*/*", "accept-encoding": "gzip, deflate", "connection": "upgrade", "host": "testserver", "user-agent": "testclient", "sec-websocket-key": "testserver==", "sec-websocket-version": "13", } data = websocket.receive_json() assert data == {"headers": expected_headers} def test_websocket_port(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() await websocket.send_json({"port": websocket.url.port}) await websocket.close() client = test_client_factory(app) with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"port": 123} def test_websocket_send_and_receive_text( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() data = await websocket.receive_text() await websocket.send_text("Message was: " + data) await websocket.close() client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" 
def test_websocket_send_and_receive_bytes(
    test_client_factory: TestClientFactory,
) -> None:
    """Echo a single binary frame back with a ``b"Message was: "`` prefix."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_bytes()
        reply = b"Message was: " + incoming
        await ws.send_bytes(reply)
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_bytes(b"Hello, world!")
        echoed = session.receive_bytes()
        assert echoed == b"Message was: Hello, world!"


def test_websocket_send_and_receive_json(
    test_client_factory: TestClientFactory,
) -> None:
    """Wrap one received JSON document under a ``"message"`` key and send it back."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_json()
        envelope = {"message": incoming}
        await ws.send_json(envelope)
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_json({"hello": "world"})
        echoed = session.receive_json()
        assert echoed == {"message": {"hello": "world"}}


def test_websocket_iter_text(test_client_factory: TestClientFactory) -> None:
    """Echo every text frame yielded by ``iter_text`` with a ``"Message was: "`` prefix."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        # No explicit close: the app simply runs until iter_text stops yielding.
        async for chunk in ws.iter_text():
            reply = "Message was: " + chunk
            await ws.send_text(reply)

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_text("Hello, world!")
        echoed = session.receive_text()
        assert echoed == "Message was: Hello, world!"
def test_websocket_iter_bytes(test_client_factory: TestClientFactory) -> None:
    """Echo every binary frame yielded by ``iter_bytes`` with a prefix."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        # No explicit close: the app runs until iter_bytes stops yielding.
        async for data in websocket.iter_bytes():
            await websocket.send_bytes(b"Message was: " + data)

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        websocket.send_bytes(b"Hello, world!")
        data = websocket.receive_bytes()
        assert data == b"Message was: Hello, world!"


def test_websocket_iter_json(test_client_factory: TestClientFactory) -> None:
    """Echo every JSON document yielded by ``iter_json`` wrapped under ``"message"``."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        async for data in websocket.iter_json():
            await websocket.send_json({"message": data})

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        websocket.send_json({"hello": "world"})
        data = websocket.receive_json()
        assert data == {"message": {"hello": "world"}}


def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None:
    """Read and write the same socket from separate tasks joined by a memory stream.

    ``reader`` pushes each incoming JSON message into an anyio memory stream and
    ``writer`` drains that stream back onto the socket, so the client should get
    its own message echoed unchanged.
    """
    stream_send: ObjectSendStream[MutableMapping[str, Any]]
    stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
    stream_send, stream_receive = anyio.create_memory_object_stream()

    async def reader(websocket: WebSocket) -> None:
        # Exiting `async with` closes the send side once iter_json finishes,
        # which lets writer's `async for` over the receive side terminate.
        async with stream_send:
            async for data in websocket.iter_json():
                await stream_send.send(data)

    async def writer(websocket: WebSocket) -> None:
        async with stream_receive:
            async for message in stream_receive:
                await websocket.send_json(message)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        async with anyio.create_task_group() as task_group:
            task_group.start_soon(reader, websocket)
            await writer(websocket)
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        websocket.send_json({"hello": "world"})
        data = websocket.receive_json()
        assert data == {"hello": "world"}


def test_client_close(test_client_factory: TestClientFactory) -> None:
    """A client-initiated close reaches the app as WebSocketDisconnect with code and reason."""
    close_code = None
    close_reason = None

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        nonlocal close_code, close_reason
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        try:
            await websocket.receive_text()
        except WebSocketDisconnect as exc:
            close_code = exc.code
            close_reason = exc.reason

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")
    # Checked after the `with` exits so the app has finished handling the close.
    assert close_code == status.WS_1001_GOING_AWAY
    assert close_reason == "Going Away"


@pytest.mark.anyio
async def test_client_disconnect_on_send() -> None:
    """An OSError raised by the server's send is surfaced as WebSocketDisconnect(1006).

    The app is driven directly with hand-written ``receive``/``send`` callables
    instead of the test client.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.send_text("Hello, world!")

    async def receive() -> Message:
        return {"type": "websocket.connect"}

    async def send(message: Message) -> None:
        # Let the handshake succeed; only later sends fail.
        if message["type"] == "websocket.accept":
            return
        # Simulate the exception the server would send to the application when the client disconnects.
        raise OSError

    with pytest.raises(WebSocketDisconnect) as ctx:
        await app({"type": "websocket", "path": "/"}, receive, send)
    assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE


def test_application_close(test_client_factory: TestClientFactory) -> None:
    """A close sent by the app makes the client's next receive raise WebSocketDisconnect."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close(status.WS_1001_GOING_AWAY)

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        with pytest.raises(WebSocketDisconnect) as exc:
            websocket.receive_text()
        assert exc.value.code == status.WS_1001_GOING_AWAY


def test_rejected_connection(test_client_factory: TestClientFactory) -> None:
    """Closing before accepting rejects the handshake with the given close code."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        await websocket.close(status.WS_1001_GOING_AWAY)

    client = test_client_factory(app)
    with pytest.raises(WebSocketDisconnect) as exc:
        with client.websocket_connect("/"):
            # Not expected to run: connecting is expected to raise.
            pass  # pragma: no cover
    assert exc.value.code == status.WS_1001_GOING_AWAY


def test_send_denial_response(test_client_factory: TestClientFactory) -> None:
    """An HTTP denial response sent instead of accepting is surfaced as WebSocketDenialResponse."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        response = Response(status_code=404, content="foo")
        await websocket.send_denial_response(response)

    client = test_client_factory(app)
    with pytest.raises(WebSocketDenialResponse) as exc:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc.value.status_code == 404
    assert exc.value.content == b"foo"


def test_send_denial_response_with_streaming_response(test_client_factory: TestClientFactory) -> None:
    """A streamed denial body sent in two chunks reaches the client concatenated."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        message = await websocket.receive()
        assert message == {"type": "websocket.connect"}

        async def content() -> AsyncGenerator[bytes]:
            yield b"hello"
            yield b"world"

        await websocket.send_denial_response(StreamingResponse(content(), status_code=403))

    client = test_client_factory(app)
    with pytest.raises(WebSocketDenialResponse) as exc:
        with client.websocket_connect("/"):
            ...  # pragma: no cover
    assert exc.value.status_code == 403
    assert exc.value.content == b"helloworld"


def test_send_denial_response_with_file_response(test_client_factory: TestClientFactory, tmp_path: Path) -> None:
    """A FileResponse can be used as a denial response; the file's bytes become the body."""
    file_path = tmp_path / "denial.txt"
    file_path.write_text("test content")

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        await websocket.send_denial_response(FileResponse(file_path, status_code=401))

    client = test_client_factory(app)
    with pytest.raises(WebSocketDenialResponse) as exc:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc.value.status_code == 401
    assert exc.value.content == b"test content"


def test_send_response_multi(test_client_factory: TestClientFactory) -> None:
    """Raw denial-response messages: one start plus a body split across two messages."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        await websocket.send(
            {
                "type": "websocket.http.response.start",
                "status": 404,
                "headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
            }
        )
        # `more_body: True` signals that another body message follows.
        await websocket.send({"type": "websocket.http.response.body", "body": b"hard", "more_body": True})
        await websocket.send({"type": "websocket.http.response.body", "body": b"body"})

    client = test_client_factory(app)
    with pytest.raises(WebSocketDenialResponse) as exc:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc.value.status_code == 404
    assert exc.value.content == b"hardbody"
    assert exc.value.headers["foo"] == "bar"


def test_send_response_unsupported(test_client_factory: TestClientFactory) -> None:
    """Without the denial-response extension in scope, send_denial_response raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Remove the extension the test client advertises, to simulate a server without it.
        del scope["extensions"]["websocket.http.response"]
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        response = Response(status_code=404, content="foo")
        with pytest.raises(
            RuntimeError,
            match="The server doesn't support the Websocket Denial Response extension.",
        ):
            await websocket.send_denial_response(response)
        # The app then falls back to a plain close, which the client sees as a normal closure.
        await websocket.close()

    client = test_client_factory(app)
    with pytest.raises(WebSocketDisconnect) as exc:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc.value.code == status.WS_1000_NORMAL_CLOSURE


def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None:
    """A second ``websocket.http.response.start`` where a body is expected raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        response = Response(status_code=404, content="foo")
        await websocket.send(
            {
                "type": "websocket.http.response.start",
                "status": response.status_code,
                "headers": response.raw_headers,
            }
        )
        await websocket.send(
            {
                "type": "websocket.http.response.start",
                "status": response.status_code,
                "headers": response.raw_headers,
            }
        )

    client = test_client_factory(app)
    # NOTE(review): `match` is applied as a regular expression (re.search), so the
    # unescaped `.` characters here match any character, not only a literal dot.
    with pytest.raises(
        RuntimeError,
        match=("Expected ASGI message \"websocket.http.response.body\", but got 'websocket.http.response.start'"),
    ):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_subprotocol(test_client_factory: TestClientFactory) -> None:
    """Client-requested subprotocols appear in scope; the one the app accepts is reported back."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        assert websocket["subprotocols"] == ["soap", "wamp"]
        await websocket.accept(subprotocol="wamp")
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket:
        assert websocket.accepted_subprotocol == "wamp"


def test_additional_headers(test_client_factory: TestClientFactory) -> None:
    """Headers passed to ``accept`` are exposed to the client as ``extra_headers``."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept(headers=[(b"additional", b"header")])
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        assert websocket.extra_headers == [(b"additional", b"header")]


def test_no_additional_headers(test_client_factory: TestClientFactory) -> None:
    """Accepting without headers leaves the client's ``extra_headers`` empty."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        assert websocket.extra_headers == []


def test_websocket_exception(test_client_factory: TestClientFactory) -> None:
    """An exception raised inside the app propagates out of ``websocket_connect``."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        assert False

    client = test_client_factory(app)
    with pytest.raises(AssertionError):
        with client.websocket_connect("/123?a=abc"):
            pass  # pragma: no cover


def test_duplicate_close(test_client_factory: TestClientFactory) -> None:
    """Closing an already-closed socket from the app side raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close()
        await websocket.close()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None:
    """Receiving again after a ``websocket.disconnect`` message raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        message = await websocket.receive()
        assert message["type"] == "websocket.disconnect"
        # Expected to raise: nothing may be received after a disconnect.
        message = await websocket.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as websocket:
            websocket.close()


def test_websocket_scope_interface() -> None:
    """
    A WebSocket can be instantiated with a scope, and presents a `Mapping`
    interface over that scope. Equality and hashing are checked as well.
    """

    async def mock_receive() -> Message:  # type: ignore
        ...  # pragma: no cover

    async def mock_send(message: Message) -> None:
        ...  # pragma: no cover

    websocket = WebSocket({"type": "websocket", "path": "/abc/", "headers": []}, receive=mock_receive, send=mock_send)
    assert websocket["type"] == "websocket"
    assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
    assert len(websocket) == 3

    # check __eq__ and __hash__
    # Two instances built from equal scopes are still expected to compare unequal.
    assert websocket != WebSocket(
        {"type": "websocket", "path": "/abc/", "headers": []},
        receive=mock_receive,
        send=mock_send,
    )
    assert websocket == websocket
    assert websocket in {websocket}
    assert {websocket} == {websocket}


def test_websocket_close_reason(test_client_factory: TestClientFactory) -> None:
    """The code and reason passed to ``close`` reach the client's WebSocketDisconnect."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        with pytest.raises(WebSocketDisconnect) as exc:
            websocket.receive_text()
        assert exc.value.code == status.WS_1001_GOING_AWAY
        assert exc.value.reason == "Going Away"


def test_send_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
    """``send_json`` with an unsupported ``mode`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.send_json({}, mode="invalid")

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_receive_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
    """``receive_json`` with an unsupported ``mode`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.receive_json(mode="invalid")

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_receive_text_before_accept(test_client_factory: TestClientFactory) -> None:
    """Calling ``receive_text`` before ``accept`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.receive_text()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None:
    """Calling ``receive_bytes`` before ``accept`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.receive_bytes()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_receive_json_before_accept(test_client_factory: TestClientFactory) -> None:
    """Calling ``receive_json`` before ``accept`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.receive_json()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_send_before_accept(test_client_factory: TestClientFactory) -> None:
    """Sending a ``websocket.send`` message before the handshake raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.send({"type": "websocket.send"})

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None:
    """Sending ``websocket.accept`` twice raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.send({"type": "websocket.accept"})
        await websocket.send({"type": "websocket.accept"})

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_receive_before_accept(test_client_factory: TestClientFactory) -> None:
    """A non-connect message received while in the CONNECTING state raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        # Force the state back to CONNECTING so the next receive is validated as a handshake message.
        websocket.client_state = WebSocketState.CONNECTING
        await websocket.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as websocket:
            websocket.send({"type": "websocket.send"})


def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None:
    """A ``websocket.connect`` message arriving after accept raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as websocket:
            websocket.send({"type": "websocket.connect"})

================================================
FILE: tests/types.py
================================================
# Typing helper: the call signature of a factory that wraps an ASGI app in a TestClient.
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

import httpx

from starlette.testclient import TestClient
from starlette.types import ASGIApp

if TYPE_CHECKING:
    # Type checkers see a Protocol whose __call__ mirrors the factory's keyword
    # arguments and returns a TestClient.
    class TestClientFactory(Protocol):  # pragma: no cover
        def __call__(
            self,
            app: ASGIApp,
            base_url: str = "http://testserver",
            raise_server_exceptions: bool = True,
            root_path: str = "",
            cookies: httpx._types.CookieTypes | None = None,
            headers: dict[str, str] | None = None,
            follow_redirects: bool = True,
            client: tuple[str, int] = ("testclient", 50000),
        ) -> TestClient: ...
else:  # pragma: no cover
    # At runtime only a placeholder class exists. `__test__ = False` stops pytest
    # from collecting this `Test*`-named class as a test class.
    class TestClientFactory:
        __test__ = False